Spaces:
Runtime error
Runtime error
| # Copyright 2023 DeepMind Technologies Limited | |
| # | |
| # Licensed under the Apache License, Version 2.0 (the "License"); | |
| # you may not use this file except in compliance with the License. | |
| # You may obtain a copy of the License at | |
| # | |
| # http://www.apache.org/licenses/LICENSE-2.0 | |
| # | |
| # Unless required by applicable law or agreed to in writing, software | |
| # distributed under the License is distributed on an "AS IS" BASIS, | |
| # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |
| # See the License for the specific language governing permissions and | |
| # limitations under the License. | |
| # ============================================================================== | |
| """Implementing Algebraic Reasoning (AR).""" | |
| from collections import defaultdict # pylint: disable=g-importing-member | |
| from fractions import Fraction as frac # pylint: disable=g-importing-member | |
| from typing import Any, Generator | |
| import geometry as gm | |
| import numpy as np | |
| import problem as pr | |
| from scipy import optimize | |
| class InfQuotientError(Exception): | |
| pass | |
| def _gcd(x: int, y: int) -> int: | |
| while y: | |
| x, y = y, x % y | |
| return x | |
| def simplify(n: int, d: int) -> tuple[int, int]: | |
| g = _gcd(n, d) | |
| return (n // g, d // g) | |
| # maximum denominator for a fraction. | |
| MAX_DENOMINATOR = 1000000 | |
| # tolerance for fraction approximation | |
| TOL = 1e-15 | |
| def get_quotient(v: float) -> tuple[int, int]: | |
| n = v | |
| d = 1 | |
| while abs(n - round(n)) > TOL: | |
| d += 1 | |
| n += v | |
| if d > MAX_DENOMINATOR: | |
| e = InfQuotientError(v) | |
| raise e | |
| n = int(round(n)) | |
| return simplify(n, d) | |
| def fix_v(v: float) -> float: | |
| n, d = get_quotient(v) | |
| return n / d | |
| def fix(e: dict[str, float]) -> dict[str, float]: | |
| return {k: fix_v(v) for k, v in e.items()} | |
| def frac_string(f: frac) -> str: | |
| n, d = get_quotient(f) | |
| return f'{n}/{d}' | |
| def hashed(e: dict[str, float]) -> tuple[tuple[str, float], ...]: | |
| return tuple(sorted(list(e.items()))) | |
| def is_zero(e: dict[str, float]) -> bool: | |
| return len(strip(e)) == 0 # pylint: disable=g-explicit-length-test | |
| def strip(e: dict[str, float]) -> dict[str, float]: | |
| return {v: c for v, c in e.items() if c != 0} | |
| def plus(e1: dict[str, float], e2: dict[str, float]) -> dict[str, float]: | |
| e = dict(e1) | |
| for v, c in e2.items(): | |
| if v in e: | |
| e[v] += c | |
| else: | |
| e[v] = c | |
| return strip(e) | |
| def plus_all(*es: list[dict[str, float]]) -> dict[str, float]: | |
| result = {} | |
| for e in es: | |
| result = plus(result, e) | |
| return result | |
| def mult(e: dict[str, float], m: float) -> dict[str, float]: | |
| return {v: m * c for v, c in e.items()} | |
| def minus(e1: dict[str, float], e2: dict[str, float]) -> dict[str, float]: | |
| return plus(e1, mult(e2, -1)) | |
| def div(e1: dict[str, float], e2: dict[str, float]) -> float: | |
| """Divide e1 by e2.""" | |
| e1 = strip(e1) | |
| e2 = strip(e2) | |
| if set(e1.keys()) != set(e2.keys()): | |
| return None | |
| n, d = None, None | |
| for v, c1 in e1.items(): | |
| c2 = e2[v] # we want c1/c2 = n/d => c1*d=c2*n | |
| if n is not None and c1 * d != c2 * n: | |
| return None | |
| n, d = c1, c2 | |
| return frac(n) / frac(d) | |
| def recon(e: dict[str, float], const: str) -> tuple[str, dict[str, float]]: | |
| """Reconcile one variable in the expression e=0, given const.""" | |
| e = strip(e) | |
| if len(e) == 0: # pylint: disable=g-explicit-length-test | |
| return None | |
| v0 = None | |
| for v in e: | |
| if v != const: | |
| v0 = v | |
| break | |
| if v0 is None: | |
| return v0 | |
| c0 = e.pop(v0) | |
| return v0, {v: -c / c0 for v, c in e.items()} | |
| def replace( | |
| e: dict[str, float], v0: str, e0: dict[str, float] | |
| ) -> dict[str, float]: | |
| if v0 not in e: | |
| return e | |
| e = dict(e) | |
| m = e.pop(v0) | |
| return plus(e, mult(e0, m)) | |
| def comb2(elems: list[Any]) -> Generator[tuple[Any, Any], None, None]: | |
| if len(elems) < 1: | |
| return | |
| for i, e1 in enumerate(elems[:-1]): | |
| for e2 in elems[i + 1 :]: | |
| yield e1, e2 | |
| def perm2(elems: list[Any]) -> Generator[tuple[Any, Any], None, None]: | |
| for e1, e2 in comb2(elems): | |
| yield e1, e2 | |
| yield e2, e1 | |
| def chain2(elems: list[Any]) -> Generator[tuple[Any, Any], None, None]: | |
| if len(elems) < 2: | |
| return | |
| for i, e1 in enumerate(elems[:-1]): | |
| yield e1, elems[i + 1] | |
| def update_groups( | |
| groups1: list[Any], groups2: list[Any] | |
| ) -> tuple[list[Any], list[tuple[Any, Any]], list[list[Any]]]: | |
| """Update groups of equivalent elements. | |
| Given groups1 = [set1, set2, set3, ..] | |
| where all elems within each set_i is defined to be "equivalent" to each other. | |
| (but not across the sets) | |
| Incoming groups2 = [set1, set2, ...] similar to set1 - it is the | |
| additional equivalent information on elements in groups1. | |
| Return the new updated groups1 and the set of links | |
| that make it that way. | |
| Example: | |
| groups1 = [{1, 2}, {3, 4, 5}, {6, 7}] | |
| groups2 = [{2, 3, 8}, {9, 10, 11}] | |
| => new groups1 and links: | |
| groups1 = [{1, 2, 3, 4, 5, 8}, {6, 7}, {9, 10, 11}] | |
| links = (2, 3), (3, 8), (9, 10), (10, 11) | |
| Explain: since groups2 says 2 and 3 are equivalent (with {2, 3, 8}), | |
| then {1, 2} and {3, 4, 5} in groups1 will be merged, | |
| because 2 and 3 each belong to those 2 groups. | |
| Additionally 8 also belong to this same group. | |
| {3, 4, 5} is left alone, while {9, 10, 11} is a completely new set. | |
| The links to make this all happens is: | |
| (2, 3): to merge {1, 2} and {3, 4, 5} | |
| (3, 8): to link 8 into the merged({1, 2, 3, 4, 5}) | |
| (9, 10) and (10, 11): to make the new group {9, 10, 11} | |
| Args: | |
| groups1: a list of sets. | |
| groups2: a list of sets. | |
| Returns: | |
| groups1, links, history: result of the update. | |
| """ | |
| history = [] | |
| links = [] | |
| for g2 in groups2: | |
| joins = [None] * len(groups1) # mark which one in groups1 is merged | |
| merged_g1 = set() # merge them into this. | |
| old = None # any elem in g2 that belong to any set in groups1 (old) | |
| new = [] # all elem in g2 that is new | |
| for e in g2: | |
| found = False | |
| for i, g1 in enumerate(groups1): | |
| if e not in g1: | |
| continue | |
| found = True | |
| if joins[i]: | |
| continue | |
| joins[i] = True | |
| merged_g1.update(g1) | |
| if old is not None: | |
| links.append((old, e)) # link to make merging happen. | |
| old = e | |
| if not found: # e is new! | |
| new.append(e) | |
| # now chain elems in new together. | |
| if old is not None and new: | |
| links.append((old, new[0])) | |
| merged_g1.update(new) | |
| links += chain2(new) | |
| new_groups1 = [] | |
| if merged_g1: # put the merged_g1 in first | |
| new_groups1.append(merged_g1) | |
| # put the remaining (unjoined) groups in | |
| new_groups1 += [g1 for j, g1 in zip(joins, groups1) if not j] | |
| if old is None and new: | |
| new_groups1 += [set(new)] | |
| groups1 = new_groups1 | |
| history.append(groups1) | |
| return groups1, links, history | |
| class Table: | |
| """The coefficient matrix.""" | |
| def __init__(self, const: str = '1'): | |
| self.const = const | |
| self.v2e = {} | |
| self.add_free(const) # the table {var: expression} | |
| # to cache what is already derived/inputted | |
| self.eqs = set() | |
| self.groups = [] # groups of equal pairs. | |
| # for why (linprog) | |
| self.c = [] | |
| self.v2i = {} # v -> index of row in A. | |
| self.deps = [] # equal number of columns. | |
| self.A = np.zeros([0, 0]) # pylint: disable=invalid-name | |
| self.do_why = True | |
| def add_free(self, v: str) -> None: | |
| self.v2e[v] = {v: frac(1)} | |
| def replace(self, v0: str, e0: dict[str, float]) -> None: | |
| for v, e in list(self.v2e.items()): | |
| self.v2e[v] = replace(e, v0, e0) | |
| def add_expr(self, vc: list[tuple[str, float]]) -> bool: | |
| """Add a new equality, represented by the list of tuples vc=[(v, c), ..].""" | |
| result = {} | |
| free = [] | |
| for v, c in vc: | |
| c = frac(c) | |
| if v in self.v2e: | |
| result = plus(result, mult(self.v2e[v], c)) | |
| else: | |
| free += [(v, c)] | |
| if free == []: # pylint: disable=g-explicit-bool-comparison | |
| if is_zero(self.modulo(result)): | |
| return False | |
| result = recon(result, self.const) | |
| if result is None: | |
| return False | |
| v, e = result | |
| self.replace(v, e) | |
| elif len(free) == 1: | |
| v, m = free[0] | |
| self.v2e[v] = mult(result, frac(-1, m)) | |
| else: | |
| dependent_v = None | |
| for v, m in free: | |
| if dependent_v is None and v != self.const: | |
| dependent_v = (v, m) | |
| continue | |
| self.add_free(v) | |
| result = plus(result, {v: m}) | |
| v, m = dependent_v | |
| self.v2e[v] = mult(result, frac(-1, m)) | |
| return True | |
| def register(self, vc: list[tuple[str, float]], dep: pr.Dependency) -> None: | |
| """Register a new equality vc=[(v, c), ..] with traceback dependency dep.""" | |
| result = plus_all(*[{v: c} for v, c in vc]) | |
| if is_zero(result): | |
| return | |
| vs, _ = zip(*vc) | |
| for v in vs: | |
| if v not in self.v2i: | |
| self.v2i[v] = len(self.v2i) | |
| (m, n), l = self.A.shape, len(self.v2i) | |
| if l > m: | |
| self.A = np.concatenate([self.A, np.zeros([l - m, n])], 0) | |
| new_column = np.zeros([len(self.v2i), 2]) # N, 2 | |
| for v, c in vc: | |
| new_column[self.v2i[v], 0] += float(c) | |
| new_column[self.v2i[v], 1] -= float(c) | |
| self.A = np.concatenate([self.A, new_column], 1) | |
| self.c += [1.0, -1.0] | |
| self.deps += [dep] | |
| def register2( | |
| self, a: str, b: str, m: float, n: float, dep: pr.Dependency | |
| ) -> None: | |
| self.register([(a, m), (b, -n)], dep) | |
| def register3(self, a: str, b: str, f: float, dep: pr.Dependency) -> None: | |
| self.register([(a, 1), (b, -1), (self.const, -f)], dep) | |
| def register4( | |
| self, a: str, b: str, c: str, d: str, dep: pr.Dependency | |
| ) -> None: | |
| self.register([(a, 1), (b, -1), (c, -1), (d, 1)], dep) | |
| def why(self, e: dict[str, float]) -> list[Any]: | |
| """AR traceback == MILP.""" | |
| if not self.do_why: | |
| return [] | |
| # why expr == 0? | |
| # Solve min(c^Tx) s.t. A_eq * x = b_eq, x >= 0 | |
| e = strip(e) | |
| if not e: | |
| return [] | |
| b_eq = [0] * len(self.v2i) | |
| for v, c in e.items(): | |
| b_eq[self.v2i[v]] += float(c) | |
| try: | |
| x = optimize.linprog(c=self.c, A_eq=self.A, b_eq=b_eq, method='highs')[ | |
| 'x' | |
| ] | |
| except: # pylint: disable=bare-except | |
| x = optimize.linprog( | |
| c=self.c, | |
| A_eq=self.A, | |
| b_eq=b_eq, | |
| )['x'] | |
| deps = [] | |
| for i, dep in enumerate(self.deps): | |
| if x[2 * i] > 1e-12 or x[2 * i + 1] > 1e-12: | |
| if dep not in deps: | |
| deps.append(dep) | |
| return deps | |
| def record_eq(self, v1: str, v2: str, v3: str, v4: str) -> None: | |
| self.eqs.add((v1, v2, v3, v4)) | |
| self.eqs.add((v2, v1, v4, v3)) | |
| self.eqs.add((v3, v4, v1, v2)) | |
| self.eqs.add((v4, v3, v2, v1)) | |
| def check_record_eq(self, v1: str, v2: str, v3: str, v4: str) -> bool: | |
| if (v1, v2, v3, v4) in self.eqs: | |
| return True | |
| if (v2, v1, v4, v3) in self.eqs: | |
| return True | |
| if (v3, v4, v1, v2) in self.eqs: | |
| return True | |
| if (v4, v3, v2, v1) in self.eqs: | |
| return True | |
| return False | |
| def add_eq2( | |
| self, a: str, b: str, m: float, n: float, dep: pr.Dependency | |
| ) -> None: | |
| # a/b = m/n | |
| if not self.add_expr([(a, n), (b, -m)]): | |
| return [] | |
| self.register2(a, b, m, n, dep) | |
| def add_eq3(self, a: str, b: str, f: float, dep: pr.Dependency) -> None: | |
| # a - b = f * constant | |
| self.eqs.add((a, b, frac(f))) | |
| self.eqs.add((b, a, frac(1 - f))) | |
| if not self.add_expr([(a, 1), (b, -1), (self.const, -f)]): | |
| return [] | |
| self.register3(a, b, f, dep) | |
| def add_eq4(self, a: str, b: str, c: str, d: str, dep: pr.Dependency) -> None: | |
| # a - b = c - d | |
| self.record_eq(a, b, c, d) | |
| self.record_eq(a, c, b, d) | |
| expr = list(minus({a: 1, b: -1}, {c: 1, d: -1}).items()) | |
| if not self.add_expr(expr): | |
| return [] | |
| self.register4(a, b, c, d, dep) | |
| self.groups, _, _ = update_groups( | |
| self.groups, [{(a, b), (c, d)}, {(b, a), (d, c)}] | |
| ) | |
| def pairs(self) -> Generator[list[tuple[str, str]], None, None]: | |
| for v1, v2 in perm2(list(self.v2e.keys())): # pylint: disable=g-builtin-op | |
| if v1 == self.const or v2 == self.const: | |
| continue | |
| yield v1, v2 | |
| def modulo(self, e: dict[str, float]) -> dict[str, float]: | |
| return strip(e) | |
| def get_all_eqs( | |
| self, | |
| ) -> dict[tuple[tuple[str, float], ...], list[tuple[str, str]]]: | |
| h2pairs = defaultdict(list) | |
| for v1, v2 in self.pairs(): | |
| e1, e2 = self.v2e[v1], self.v2e[v2] | |
| e12 = minus(e1, e2) | |
| h12 = hashed(self.modulo(e12)) | |
| h2pairs[h12].append((v1, v2)) | |
| return h2pairs | |
| def get_all_eqs_and_why( | |
| self, return_quads: bool = True | |
| ) -> Generator[Any, None, None]: | |
| """Check all 4/3/2-permutations for new equalities.""" | |
| groups = [] | |
| for h, vv in self.get_all_eqs().items(): | |
| if h == (): # pylint: disable=g-explicit-bool-comparison | |
| for v1, v2 in vv: | |
| if (v1, v2) in self.eqs or (v2, v1) in self.eqs: | |
| continue | |
| self.eqs.add((v1, v2)) | |
| # why v1 - v2 = e12 ? (note modulo(e12) == 0) | |
| why_dict = minus({v1: 1, v2: -1}, minus(self.v2e[v1], self.v2e[v2])) | |
| yield v1, v2, self.why(why_dict) | |
| continue | |
| if len(h) == 1 and h[0][0] == self.const: | |
| for v1, v2 in vv: | |
| frac = h[0][1] # pylint: disable=redefined-outer-name | |
| if (v1, v2, frac) in self.eqs: | |
| continue | |
| self.eqs.add((v1, v2, frac)) | |
| # why v1 - v2 = e12 ? (note modulo(e12) == 0) | |
| why_dict = minus({v1: 1, v2: -1}, minus(self.v2e[v1], self.v2e[v2])) | |
| value = simplify(frac.numerator, frac.denominator) | |
| yield v1, v2, value, self.why(why_dict) | |
| continue | |
| groups.append(vv) | |
| if not return_quads: | |
| return | |
| self.groups, links, _ = update_groups(self.groups, groups) | |
| for (v1, v2), (v3, v4) in links: | |
| if self.check_record_eq(v1, v2, v3, v4): | |
| continue | |
| e12 = minus(self.v2e[v1], self.v2e[v2]) | |
| e34 = minus(self.v2e[v3], self.v2e[v4]) | |
| why_dict = minus( # why (v1-v2)-(v3-v4)=e12-e34? | |
| minus({v1: 1, v2: -1}, {v3: 1, v4: -1}), minus(e12, e34) | |
| ) | |
| self.record_eq(v1, v2, v3, v4) | |
| yield v1, v2, v3, v4, self.why(why_dict) | |
| class GeometricTable(Table): | |
| """Abstract class representing the coefficient matrix (table) A.""" | |
| def __init__(self, name: str = ''): | |
| super().__init__(name) | |
| self.v2obj = {} | |
| def get_name(self, objs: list[Any]) -> list[str]: | |
| self.v2obj.update({o.name: o for o in objs}) | |
| return [o.name for o in objs] | |
| def map2obj(self, names: list[str]) -> list[Any]: | |
| return [self.v2obj[n] for n in names] | |
| def get_all_eqs_and_why( | |
| self, return_quads: bool | |
| ) -> Generator[Any, None, None]: | |
| for out in super().get_all_eqs_and_why(return_quads): | |
| if len(out) == 3: | |
| x, y, why = out | |
| x, y = self.map2obj([x, y]) | |
| yield x, y, why | |
| if len(out) == 4: | |
| x, y, f, why = out | |
| x, y = self.map2obj([x, y]) | |
| yield x, y, f, why | |
| if len(out) == 5: | |
| a, b, x, y, why = out | |
| a, b, x, y = self.map2obj([a, b, x, y]) | |
| yield a, b, x, y, why | |
| class RatioTable(GeometricTable): | |
| """Coefficient matrix A for log(distance).""" | |
| def __init__(self, name: str = ''): | |
| name = name or '1' | |
| super().__init__(name) | |
| self.one = self.const | |
| def add_eq(self, l1: gm.Length, l2: gm.Length, dep: pr.Dependency) -> None: | |
| l1, l2 = self.get_name([l1, l2]) | |
| return super().add_eq3(l1, l2, 0.0, dep) | |
| def add_const_ratio( | |
| self, l1: gm.Length, l2: gm.Length, m: float, n: float, dep: pr.Dependency | |
| ) -> None: | |
| l1, l2 = self.get_name([l1, l2]) | |
| return super().add_eq2(l1, l2, m, n, dep) | |
| def add_eqratio( | |
| self, | |
| l1: gm.Length, | |
| l2: gm.Length, | |
| l3: gm.Length, | |
| l4: gm.Length, | |
| dep: pr.Dependency, | |
| ) -> None: | |
| l1, l2, l3, l4 = self.get_name([l1, l2, l3, l4]) | |
| return self.add_eq4(l1, l2, l3, l4, dep) | |
| def get_all_eqs_and_why(self) -> Generator[Any, None, None]: | |
| return super().get_all_eqs_and_why(True) | |
| class AngleTable(GeometricTable): | |
| """Coefficient matrix A for slope(direction).""" | |
| def __init__(self, name: str = ''): | |
| name = name or 'pi' | |
| super().__init__(name) | |
| self.pi = self.const | |
| def modulo(self, e: dict[str, float]) -> dict[str, float]: | |
| e = strip(e) | |
| if self.pi not in e: | |
| return super().modulo(e) | |
| e[self.pi] = e[self.pi] % 1 | |
| return strip(e) | |
| def add_para( | |
| self, d1: gm.Direction, d2: gm.Direction, dep: pr.Dependency | |
| ) -> None: | |
| return self.add_const_angle(d1, d2, 0, dep) | |
| def add_const_angle( | |
| self, d1: gm.Direction, d2: gm.Direction, ang: float, dep: pr.Dependency | |
| ) -> None: | |
| if ang and d2._obj.num > d1._obj.num: # pylint: disable=protected-access | |
| d1, d2 = d2, d1 | |
| ang = 180 - ang | |
| d1, d2 = self.get_name([d1, d2]) | |
| num, den = simplify(ang, 180) | |
| ang = frac(int(num), int(den)) | |
| return super().add_eq3(d1, d2, ang, dep) | |
| def add_eqangle( | |
| self, | |
| d1: gm.Direction, | |
| d2: gm.Direction, | |
| d3: gm.Direction, | |
| d4: gm.Direction, | |
| dep: pr.Dependency, | |
| ) -> None: | |
| """Add the inequality d1-d2=d3-d4.""" | |
| # Use string as variables. | |
| l1, l2, l3, l4 = [d._obj.num for d in [d1, d2, d3, d4]] # pylint: disable=protected-access | |
| d1, d2, d3, d4 = self.get_name([d1, d2, d3, d4]) | |
| ang1 = {d1: 1, d2: -1} | |
| ang2 = {d3: 1, d4: -1} | |
| if l2 > l1: | |
| ang1 = plus({self.pi: 1}, ang1) | |
| if l4 > l3: | |
| ang2 = plus({self.pi: 1}, ang2) | |
| ang12 = minus(ang1, ang2) | |
| self.record_eq(d1, d2, d3, d4) | |
| self.record_eq(d1, d3, d2, d4) | |
| expr = list(ang12.items()) | |
| if not self.add_expr(expr): | |
| return [] | |
| self.register(expr, dep) | |
| def get_all_eqs_and_why(self) -> Generator[Any, None, None]: | |
| return super().get_all_eqs_and_why(True) | |
| class DistanceTable(GeometricTable): | |
| """Coefficient matrix A for position(point, line).""" | |
| def __init__(self, name: str = ''): | |
| name = name or '1:1' | |
| self.merged = {} | |
| self.ratios = set() | |
| super().__init__(name) | |
| def pairs(self) -> Generator[tuple[str, str], None, None]: | |
| l2vs = defaultdict(list) | |
| for v in list(self.v2e.keys()): # pylint: disable=g-builtin-op | |
| if v == self.const: | |
| continue | |
| l, p = v.split(':') | |
| l2vs[l].append(p) | |
| for l, ps in l2vs.items(): | |
| for p1, p2 in perm2(ps): | |
| yield l + ':' + p1, l + ':' + p2 | |
| def name(self, l: gm.Line, p: gm.Point) -> str: | |
| v = l.name + ':' + p.name | |
| self.v2obj[v] = (l, p) | |
| return v | |
| def map2obj(self, names: list[str]) -> list[gm.Point]: | |
| return [self.v2obj[n][1] for n in names] | |
| def add_cong( | |
| self, | |
| l12: gm.Line, | |
| l34: gm.Line, | |
| p1: gm.Point, | |
| p2: gm.Point, | |
| p3: gm.Point, | |
| p4: gm.Point, | |
| dep: pr.Dependency, | |
| ) -> None: | |
| """Add that distance between p1 and p2 (on l12) == p3 and p4 (on l34).""" | |
| if p2.num > p1.num: | |
| p1, p2 = p2, p1 | |
| if p4.num > p3.num: | |
| p3, p4 = p4, p3 | |
| p1 = self.name(l12, p1) | |
| p2 = self.name(l12, p2) | |
| p3 = self.name(l34, p3) | |
| p4 = self.name(l34, p4) | |
| return super().add_eq4(p1, p2, p3, p4, dep) | |
| def get_all_eqs_and_why(self) -> Generator[Any, None, None]: | |
| for x in super().get_all_eqs_and_why(True): | |
| yield x | |
| # Now we figure out all the const ratios. | |
| h2pairs = defaultdict(list) | |
| for v1, v2 in self.pairs(): | |
| if (v1, v2) in self.merged: | |
| continue | |
| e1, e2 = self.v2e[v1], self.v2e[v2] | |
| e12 = minus(e1, e2) | |
| h12 = hashed(e12) | |
| h2pairs[h12].append((v1, v2, e12)) | |
| for (_, vves1), (_, vves2) in perm2(list(h2pairs.items())): | |
| v1, v2, e12 = vves1[0] | |
| for v1_, v2_, _ in vves1[1:]: | |
| self.merged[(v1_, v2_)] = (v1, v2) | |
| v3, v4, e34 = vves2[0] | |
| for v3_, v4_, _ in vves2[1:]: | |
| self.merged[(v3_, v4_)] = (v3, v4) | |
| if (v1, v2, v3, v4) in self.ratios: | |
| continue | |
| d12 = div(e12, e34) | |
| if d12 is None or d12 > 1 or d12 < 0: | |
| continue | |
| self.ratios.add((v1, v2, v3, v4)) | |
| self.ratios.add((v2, v1, v4, v3)) | |
| n, d = d12.numerator, d12.denominator | |
| # (v1 - v2) * d = (v3 - v4) * n | |
| why_dict = minus( | |
| minus({v1: d, v2: -d}, {v3: n, v4: -n}), | |
| minus(mult(e12, d), mult(e34, n)), # there is no modulo, so this is 0 | |
| ) | |
| v1, v2, v3, v4 = self.map2obj([v1, v2, v3, v4]) | |
| yield v1, v2, v3, v4, abs(n), abs(d), self.why(why_dict) | |