| """ Generic Unification algorithm for expression trees with lists of children | |
| This implementation is a direct translation of | |
| Artificial Intelligence: A Modern Approach by Stuart Russell and Peter Norvig | |
| Second edition, section 9.2, page 276 | |
| It is modified in the following ways: | |
| 1. We allow associative and commutative Compound expressions. This results in | |
| combinatorial blowup. | |
| 2. We explore the tree lazily. | |
| 3. We provide generic interfaces to symbolic algebra libraries in Python. | |
| A more traditional version can be found here | |
| http://aima.cs.berkeley.edu/python/logic.html | |
| """ | |
| from sympy.utilities.iterables import kbins | |
class Compound:
    """ A little class to represent an interior node in the tree

    This is analogous to SymPy.Basic for non-Atoms

    op - the operator (head) of the node.
    args - the children of the node.  Hashing a Compound hashes
        ``(type, op, args)``, so ``op`` and ``args`` must be hashable
        (e.g. ``args`` a tuple) for the node to be usable as a dict key.
    """
    def __init__(self, op, args):
        self.op = op
        self.args = args

    def __eq__(self, other):
        return (type(self) is type(other) and self.op == other.op and
                self.args == other.args)

    def __hash__(self):
        return hash((type(self), self.op, self.args))

    def __str__(self):
        return "%s[%s]" % (str(self.op), ', '.join(map(str, self.args)))

    def __repr__(self):
        # Containers (lists, the substitution dicts produced by unify, ...)
        # display their elements with repr(); reuse the readable str() form
        # instead of the default ``<... object at 0x...>``.
        return str(self)
class Variable:
    """ A Wild token

    arg - the label identifying this token; two Variables are equal (and
        hash alike) exactly when their labels are equal.
    """
    def __init__(self, arg):
        self.arg = arg

    def __eq__(self, other):
        return type(self) is type(other) and self.arg == other.arg

    def __hash__(self):
        return hash((type(self), self.arg))

    def __str__(self):
        return "Variable(%s)" % str(self.arg)

    def __repr__(self):
        # Variables are typically shown as keys of substitution dicts, which
        # use repr(); make that match the readable str() form.
        return str(self)
class CondVariable:
    """ A wild token that matches conditionally.

    arg - the label identifying this token (shown by str/repr).
    valid - an additional constraining function on a match; it is called
        with a candidate subtree and the match is kept only when it returns
        a truthy value.  It takes part in equality and hashing, so it must
        be hashable.
    """
    def __init__(self, arg, valid):
        self.arg = arg
        self.valid = valid

    def __eq__(self, other):
        return (type(self) is type(other) and
                self.arg == other.arg and
                self.valid == other.valid)

    def __hash__(self):
        return hash((type(self), self.arg, self.valid))

    def __str__(self):
        return "CondVariable(%s)" % str(self.arg)

    def __repr__(self):
        # Keep container output (e.g. substitution dicts) readable.
        return str(self)
def unify(x, y, s=None, **fns):
    """ Unify two expressions.

    Parameters
    ==========

    x, y - expression trees containing leaves, Compounds and Variables.
    s - a mapping of variables to subtrees (the substitution built so far);
        a falsy value (``None``, an empty dict) starts from a fresh dict.
    fns - optional keyword callables ``is_commutative`` and
        ``is_associative``; each is called with a Compound and its truth
        value decides whether that node's arguments may be reordered /
        regrouped.  Both default to always returning False.

    Returns
    =======

    lazy sequence of mappings {Variable: subtree}

    Examples
    ========

    >>> from sympy.unify.core import unify, Compound, Variable
    >>> expr = Compound("Add", ("x", "y"))
    >>> pattern = Compound("Add", ("x", Variable("a")))
    >>> next(unify(expr, pattern, {}))
    {Variable(a): 'y'}
    """
    subst = s or {}

    # Identical trees unify without extending the substitution.
    if x == y:
        yield subst
        return

    wild_types = (Variable, CondVariable)
    if isinstance(x, wild_types):
        yield from unify_var(x, y, subst, **fns)
        return
    if isinstance(y, wild_types):
        yield from unify_var(y, x, subst, **fns)
        return

    if isinstance(x, Compound) and isinstance(y, Compound):
        commutative = fns.get('is_commutative', lambda expr: False)
        associative = fns.get('is_associative', lambda expr: False)
        # Operators must unify first; every way they do is explored.
        for op_subst in unify(x.op, y.op, subst, **fns):
            if associative(x) and associative(y):
                # Regroup the longer argument list so both sides have the
                # same number of children, then unify them pairwise.
                if len(x.args) < len(y.args):
                    first, second = x, y
                else:
                    first, second = y, x
                mode = ('commutative'
                        if commutative(x) and commutative(y)
                        else 'associative')
                groupings = allcombinations(first.args, second.args, mode)
                for first_groups, second_groups in groupings:
                    lhs = [unpack(Compound(first.op, grp))
                           for grp in first_groups]
                    rhs = [unpack(Compound(second.op, grp))
                           for grp in second_groups]
                    yield from unify(lhs, rhs, op_subst, **fns)
            elif len(x.args) == len(y.args):
                yield from unify(x.args, y.args, op_subst, **fns)
        return

    if is_args(x) and is_args(y) and len(x) == len(y):
        if not x:
            yield subst
            return
        # Unify the heads, then thread each resulting substitution through
        # the unification of the tails.
        for head_subst in unify(x[0], y[0], subst, **fns):
            yield from unify(x[1:], y[1:], head_subst, **fns)
    # Any other combination of inputs does not unify: nothing is yielded.
def unify_var(var, x, s, **fns):
    """ Unify the wild token ``var`` with the expression ``x`` under ``s``.

    If ``var`` is already bound in ``s``, its binding is unified with ``x``.
    Otherwise ``var`` is bound to ``x`` in a copy of ``s``, unless ``var``
    occurs inside ``x`` or ``var`` is a CondVariable whose ``valid``
    predicate rejects ``x``.  Yields zero or more substitutions.
    """
    if var in s:
        yield from unify(s[var], x, s, **fns)
        return
    if occur_check(var, x):
        # Binding var to a term that contains var would be circular.
        return
    accepted = ((isinstance(var, CondVariable) and var.valid(x))
                or isinstance(var, Variable))
    if accepted:
        yield assoc(s, var, x)
def occur_check(var, x):
    """ var occurs in subtree owned by x?

    Returns True when ``var`` equals ``x`` or any node beneath it.  For a
    Compound both the operator and the arguments are searched: unify()
    unifies operators as well as arguments, so a variable hidden in an
    operator position would otherwise produce a circular binding.
    """
    if var == x:
        return True
    if isinstance(x, Compound):
        return occur_check(var, x.op) or occur_check(var, x.args)
    if is_args(x):
        return any(occur_check(var, xi) for xi in x)
    return False
def assoc(d, key, val):
    """ Return copy of d with key associated to val

    ``d`` itself is never modified; the copy is made with ``d.copy()`` so
    it has the same mapping type as ``d``.
    """
    updated = d.copy()
    updated[key] = val
    return updated
def is_args(x):
    """ Is x a traditional iterable?

    Only exact ``tuple``, ``list`` and ``set`` instances qualify; subclasses
    and other iterables (``str``, ``frozenset``, generators, ...) do not.
    """
    x_type = type(x)
    return any(x_type is kind for kind in (tuple, list, set))
def unpack(x):
    """ Unwrap a Compound that has exactly one argument into that argument.

    Anything else, including Compounds with zero or several arguments, is
    returned unchanged.
    """
    is_singleton = isinstance(x, Compound) and len(x.args) == 1
    return x.args[0] if is_singleton else x
def allcombinations(A, B, ordered):
    """
    Restructure A and B to have the same number of elements.

    Parameters
    ==========

    ordered must be either 'commutative' or 'associative'.

    A and B can be rearranged so that the larger of the two lists is
    reorganized into smaller sublists.

    Examples
    ========

    >>> from sympy.unify.core import allcombinations
    >>> for x in allcombinations((1, 2, 3), (5, 6), 'associative'): print(x)
    (((1,), (2, 3)), ((5,), (6,)))
    (((1, 2), (3,)), ((5,), (6,)))

    >>> for x in allcombinations((1, 2, 3), (5, 6), 'commutative'): print(x)
    (((1,), (2, 3)), ((5,), (6,)))
    (((1, 2), (3,)), ((5,), (6,)))
    (((1,), (3, 2)), ((5,), (6,)))
    (((1, 3), (2,)), ((5,), (6,)))
    (((2,), (1, 3)), ((5,), (6,)))
    (((2, 1), (3,)), ((5,), (6,)))
    (((2,), (3, 1)), ((5,), (6,)))
    (((2, 3), (1,)), ((5,), (6,)))
    (((3,), (1, 2)), ((5,), (6,)))
    (((3, 1), (2,)), ((5,), (6,)))
    (((3,), (2, 1)), ((5,), (6,)))
    (((3, 2), (1,)), ((5,), (6,)))
    """
    # Translate the mode names into kbins' ``ordered`` flag; any other
    # value is forwarded to kbins unchanged.
    if ordered == "commutative":
        ordered = 11
    elif ordered == "associative":
        ordered = None

    # On a length tie, A plays the role of the side that gets regrouped.
    if len(A) < len(B):
        small, big = A, B
    else:
        small, big = B, A

    # Split the positions of the bigger sequence into len(small) bins.
    bins = kbins(list(range(len(big))), len(small), ordered=ordered)
    for part in bins:
        if big == B:
            # B is regrouped; each element of A stands alone.
            yield tuple((a,) for a in A), partition(B, part)
        else:
            # A is regrouped; each element of B stands alone.
            yield partition(A, part), tuple((b,) for b in B)
def partition(it, part):
    """ Partition a tuple/list into pieces defined by indices.

    ``part`` is a sequence of index groups.  Each group is turned into one
    piece via ``index``, and the pieces are gathered into a container of
    the same type as ``it``.

    Examples
    ========

    >>> from sympy.unify.core import partition
    >>> partition((10, 20, 30, 40), [[0, 1, 2], [3]])
    ((10, 20, 30), (40,))
    """
    pieces = [index(it, group) for group in part]
    return type(it)(pieces)
def index(it, ind):
    """ Fancy indexing into an indexable iterable (tuple, list).

    Returns a container of the same type as ``it`` holding ``it[i]`` for
    each ``i`` in ``ind``, in that order.

    Examples
    ========

    >>> from sympy.unify.core import index
    >>> index([10, 20, 30], (1, 2, 0))
    [20, 30, 10]
    """
    picked = []
    for position in ind:
        picked.append(it[position])
    return type(it)(picked)