Spaces:
Runtime error
Runtime error
| # Copyright 2023 DeepMind Technologies Limited | |
| # | |
| # Licensed under the Apache License, Version 2.0 (the "License"); | |
| # you may not use this file except in compliance with the License. | |
| # You may obtain a copy of the License at | |
| # | |
| # http://www.apache.org/licenses/LICENSE-2.0 | |
| # | |
| # Unless required by applicable law or agreed to in writing, software | |
| # distributed under the License is distributed on an "AS IS" BASIS, | |
| # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |
| # See the License for the specific language governing permissions and | |
| # limitations under the License. | |
| # ============================================================================== | |
| """Implements DAG-level traceback.""" | |
| from typing import Any | |
| import geometry as gm | |
| import pretty as pt | |
| import problem | |
| # Module-level shorthand for `pt.pretty`. It is not referenced by any function | |
| # visible in this file; presumably kept for importers - confirm before removing. | |
| pretty = pt.pretty | |
def point_levels(
    setup: list[problem.Dependency], existing_points: list[gm.Point]
) -> list[tuple[set[gm.Point], list[problem.Dependency]]]:
    """Bucket constructions, and the points they mention, by point level.

    Each construction is filed under the highest `plevel` found among its
    `gm.Point` arguments. Each of those points is recorded in the bucket of
    its own `plevel`, unless it appears in `existing_points`.

    Args:
      setup: constructions to organise.
      existing_points: points to leave out of the per-level point sets; a
        falsy value disables the filter.

    Returns:
      One `(points, constructions)` pair per level, in increasing level
      order, skipping levels where both parts are empty.

    Raises:
      ValueError: if a construction has no `gm.Point` argument (`max` of an
        empty list).
    """
    buckets = []
    for dep in setup:
        dep_points = [arg for arg in dep.args if isinstance(arg, gm.Point)]
        top = max([point.plevel for point in dep_points])

        # Grow the bucket list so that index `top` exists.
        shortfall = top + 1 - len(buckets)
        if shortfall > 0:
            buckets.extend((set(), []) for _ in range(shortfall))

        for point in dep_points:
            if not (existing_points and point in existing_points):
                buckets[point.plevel][0].add(point)

        buckets[top][1].append(dep)

    return [(pts, deps) for pts, deps in buckets if pts or deps]
def point_log(
    setup: list[problem.Dependency],
    ref_id: dict[tuple[str, ...], int],
    # String annotation so the union is never evaluated at definition time
    # (keeps the module importable on interpreters without `X | None`).
    existing_points: 'list[gm.Point] | None' = None,
) -> list[tuple[set[gm.Point], list[problem.Dependency]]]:
    """Reformat setup into groups of point constructions and number them.

    The constructions are grouped with `point_levels`. Every construction
    whose hash is not yet a key of `ref_id` is assigned the next free integer
    id (the current size of `ref_id`).

    Args:
      setup: constructions to organise.
      ref_id: mapping from dependency hash to reference number; updated in
        place with any construction not seen before.
      existing_points: points to exclude from the per-level point sets.
        Defaults to None, meaning no point is excluded. (The previous
        default was the type object `list[gm.Point]` itself, written with
        `=` where `:` was intended.)

    Returns:
      The `(points, constructions)` pairs produced by `point_levels`, in the
      same order.
    """
    log = []
    for points, cons in point_levels(setup, existing_points):
        for con in cons:
            key = con.hashed()
            if key not in ref_id:
                ref_id[key] = len(ref_id)
        log.append((points, cons))
    return log
def setup_to_levels(
    setup: list[problem.Dependency],
) -> list[list[problem.Dependency]]:
    """Group constructions by the highest point level they mention.

    Args:
      setup: constructions to group.

    Returns:
      Lists of constructions in increasing level order, where a
      construction's level is the largest `plevel` among its `gm.Point`
      arguments. Levels with no construction are omitted.

    Raises:
      ValueError: if a construction has no `gm.Point` argument (`max` of an
        empty list).
    """
    by_level = []
    for dep in setup:
        depth = max([arg.plevel for arg in dep.args if isinstance(arg, gm.Point)])

        # Make sure index `depth` exists before filing the construction.
        shortfall = depth + 1 - len(by_level)
        if shortfall > 0:
            by_level.extend([] for _ in range(shortfall))

        by_level[depth].append(dep)

    return [group for group in by_level if group]
def separate_dependency_difference(
    query: problem.Dependency,
    log: list[tuple[list[problem.Dependency], list[problem.Dependency]]],
) -> tuple[
    list[tuple[list[problem.Dependency], list[problem.Dependency]]],
    list[problem.Dependency],
    list[problem.Dependency],
    set[gm.Point],
    set[gm.Point],
]:
    """Split a traceback log into proof steps, setup and auxiliary setup.

    Args:
      query: the conclusion being traced; its arguments seed the set of
        points the query relies on.
      log: `(premises, conclusions)` entries from the traceback.

    Returns:
      A 5-tuple `(log, setup, aux_setup, points, aux_points)`:
        log: entries that have premises and at least one conclusion whose
          rule is not 'c0', with premises named 'ind' removed.
        setup: given facts (conclusions of premise-less entries or of rule
          'c0', excluding those named 'ind') whose point arguments all lie
          in `points`.
        aux_setup: the remaining given facts, i.e. those mentioning a point
          outside `points`.
        points: the query arguments plus everything reachable from them
          through `rely_on`.
        aux_points: the points of `aux_setup` facts that are not in `points`.
    """
    given = []
    steps = []
    for prems, cons in log:
        if not prems:
            given.extend(cons)
            continue

        derived = []
        for con in cons:
            # Rule 'c0' conclusions count as given facts, not proof steps.
            (given if con.rule_name == 'c0' else derived).append(con)

        if derived:
            steps.append(([p for p in prems if p.name != 'ind'], derived))

    # Transitive closure of the query arguments under `rely_on`. Visit order
    # does not affect the resulting set.
    points = set(query.args)
    pending = list(query.args)
    while pending:
        current = pending.pop()
        if not isinstance(current, gm.Point):
            continue
        for parent in current.rely_on:
            if parent not in points:
                points.add(parent)
                pending.append(parent)

    setup, aux_setup, aux_points = [], [], set()
    for con in given:
        if con.name == 'ind':
            continue
        outside = [
            p for p in con.args if isinstance(p, gm.Point) and p not in points
        ]
        if outside:
            aux_setup.append(con)
            aux_points.update(outside)
        else:
            setup.append(con)

    return steps, setup, aux_setup, points, aux_points
def recursive_traceback(
    query: problem.Dependency,
) -> list[tuple[list[problem.Dependency], list[problem.Dependency]]]:
    """Recursively traceback from the query, i.e. the conclusion.

    Starting at `query`, each dependency is visited depth-first through its
    `why` list (after `remove_loop()`), unless its rule is
    `problem.CONSTRUCTION_RULE`, in which case it has no premises here.
    Dependencies whose predicate name is one of 'ncoll', 'npara', 'nperp',
    'diff' or 'sameside' are skipped entirely and therefore never show up as
    premises either.

    Args:
      query: the conclusion to trace back from.

    Returns:
      A list of `(premises, [conclusion])` entries with exactly one
      conclusion each. Conclusions sharing the same set of premise hashes
      are adjacent and share the same premises list object; groups appear in
      the order their premise set was first completed.
    """
    visited = set()
    log = []
    # Sorted tuple of premise hashes -> conclusion list of the log entry that
    # owns that premise set. Replaces a linear re-scan (and re-sort) of the
    # whole log on every visit.
    groups = {}
    skipped_names = ('ncoll', 'npara', 'nperp', 'diff', 'sameside')

    def read(q: problem.Dependency) -> None:
        q = q.remove_loop()
        hashed = q.hashed()
        if hashed in visited:
            return
        if hashed[0] in skipped_names:
            return

        prems = []
        if q.rule_name != problem.CONSTRUCTION_RULE:
            # De-duplicate the reasons by hash, keeping the first occurrence,
            # before descending into any of them.
            unique_deps = {}
            for d in q.why:
                unique_deps.setdefault(d.hashed(), d)

            for h, d in unique_deps.items():
                if h not in visited:
                    read(d)
                # Still unvisited here means `read` skipped it; it is then
                # not recorded as a premise.
                if h in visited:
                    prems.append(d)

        # Mark as visited only after the premises have been explored.
        visited.add(hashed)

        key = tuple(sorted(d.hashed() for d in prems))
        conclusions = groups.get(key)
        if conclusions is None:
            conclusions = groups[key] = []
            log.append((prems, conclusions))
        conclusions.append(q)

    read(query)

    # Post-process: separate multi-conclusion lines into one entry each.
    return [(ps, [q]) for ps, qs in log for q in qs]
def collx_to_coll_setup(
    setup: list[problem.Dependency],
) -> list[problem.Dependency]:
    """Convert collx to coll in setups and drop per-level duplicates.

    The setup is processed level by level (see `setup_to_levels`). Each
    dependency named 'collx' is renamed to 'coll' IN PLACE and its arguments
    are de-duplicated. Within a level, only the first dependency for each
    hash is kept.

    Args:
      setup: constructions to convert; 'collx' entries are mutated.

    Returns:
      The kept dependencies, flattened in level order.
    """
    result = []
    for level in setup_to_levels(setup):
        seen = set()
        for dep in level:
            if dep.name == 'collx':
                dep.name = 'coll'
                # dict.fromkeys removes repeats while keeping first-seen
                # order; going through a set would leave the argument order
                # arbitrary and liable to change between runs.
                dep.args = list(dict.fromkeys(dep.args))

            key = dep.hashed()
            if key in seen:
                continue
            seen.add(key)
            result.append(dep)
    return result
def _collx_as_coll(dep: problem.Dependency) -> problem.Dependency:
    """Rename a 'collx' dependency to 'coll' in place, de-duplicating its args."""
    if dep.name == 'collx':
        dep.name = 'coll'
        # dict.fromkeys keeps first-seen order; a set would make the argument
        # order arbitrary and liable to change between runs.
        dep.args = list(dict.fromkeys(dep.args))
    return dep


def collx_to_coll(
    setup: list[problem.Dependency],
    aux_setup: list[problem.Dependency],
    log: list[tuple[list[problem.Dependency], list[problem.Dependency]]],
) -> tuple[
    list[problem.Dependency],
    list[problem.Dependency],
    list[tuple[list[problem.Dependency], list[problem.Dependency]]],
]:
    """Convert collx to coll and dedup.

    `setup` and `aux_setup` are converted with `collx_to_coll_setup`. In the
    log, every 'collx' premise or conclusion is renamed to 'coll' IN PLACE.
    Premises are de-duplicated by hash within each entry. Conclusions are
    de-duplicated by hash across the whole log and against the converted
    setup and auxiliary setup. Entries left with no premise or no conclusion
    are dropped.

    Args:
      setup: main constructions.
      aux_setup: auxiliary constructions.
      log: `(premises, conclusions)` proof steps.

    Returns:
      `(setup, aux_setup, log)` after conversion and de-duplication.
    """
    setup = collx_to_coll_setup(setup)
    aux_setup = collx_to_coll_setup(aux_setup)

    # Hashes of everything already established, seeded with the setup facts.
    con_set = {dep.hashed() for dep in setup + aux_setup}

    new_log = []
    for prems, cons in log:
        prem_set = set()
        kept_prems = []
        for p in prems:
            key = _collx_as_coll(p).hashed()
            if key in prem_set:
                continue
            prem_set.add(key)
            kept_prems.append(p)

        kept_cons = []
        for c in cons:
            key = _collx_as_coll(c).hashed()
            if key in con_set:
                continue
            con_set.add(key)
            kept_cons.append(c)

        if not kept_cons or not kept_prems:
            continue
        new_log.append((kept_prems, kept_cons))

    return setup, aux_setup, new_log
def get_logs(
    query: problem.Dependency, g: Any, merge_trivials: bool = False
) -> tuple[
    list[problem.Dependency],
    list[problem.Dependency],
    list[tuple[list[problem.Dependency], list[problem.Dependency]]],
    set[gm.Point],
]:
    """Produce premises, auxiliary premises and proof steps for a conclusion.

    Pipeline: resolve the query via `why_me_or_cache`, trace it back, split
    the trace into setup / auxiliary setup / proof steps, convert 'collx' to
    'coll', then shorten the proof and drop unused setup facts.

    Args:
      query: the conclusion to explain.
      g: object passed through to `query.why_me_or_cache`.
      merge_trivials: forwarded to `shorten_and_shave`.

    Returns:
      `(setup, aux_setup, log, setup_points)`, where `setup_points` is the
      point set computed by `separate_dependency_difference`.
    """
    resolved = query.why_me_or_cache(g, query.level)
    trace = recursive_traceback(resolved)

    # The fifth value (auxiliary points) is not needed here.
    proof, premises, aux_premises, setup_points, _ = (
        separate_dependency_difference(resolved, trace)
    )

    premises, aux_premises, proof = collx_to_coll(premises, aux_premises, proof)
    premises, aux_premises, proof = shorten_and_shave(
        premises, aux_premises, proof, merge_trivials
    )
    return premises, aux_premises, proof, setup_points
def shorten_and_shave(
    setup: list[problem.Dependency],
    aux_setup: list[problem.Dependency],
    log: list[tuple[list[problem.Dependency], list[problem.Dependency]]],
    merge_trivials: bool = False,
) -> tuple[
    list[problem.Dependency],
    list[problem.Dependency],
    list[tuple[list[problem.Dependency], list[problem.Dependency]]],
]:
    """Shorten the proof, then keep only the setup facts it still uses.

    Args:
      setup: main constructions.
      aux_setup: auxiliary constructions.
      log: `(premises, [conclusion])` proof steps.
      merge_trivials: forwarded to `shorten_proof`.

    Returns:
      `(setup, aux_setup, log)`, where `log` is the shortened proof and both
      setup lists are filtered to the facts whose hash appears among the
      premises of some remaining step.
    """
    log, _ = shorten_proof(log, merge_trivials=merge_trivials)

    # Hashes of every premise still referenced by the shortened proof.
    used = {prem.hashed() for prems, _ in log for prem in prems}

    kept_setup = [dep for dep in setup if dep.hashed() in used]
    kept_aux = [dep for dep in aux_setup if dep.hashed() in used]
    return kept_setup, kept_aux, log
def join_prems(
    con: problem.Dependency,
    con2prems: dict[tuple[str, ...], list[problem.Dependency]],
    expanded: set[tuple[str, ...]],
) -> list[problem.Dependency]:
    """Expand a dependency into the premises it can be replaced by.

    Args:
      con: the dependency to expand.
      con2prems: maps the hash of a replaceable dependency to its premises.
      expanded: hashes that must be kept as they are, not expanded.

    Returns:
      `[con]` when its hash is in `expanded` or missing from `con2prems`;
      otherwise the concatenation, in order, of the recursive expansion of
      each of its premises.
    """
    key = con.hashed()
    if key not in con2prems or key in expanded:
        return [con]

    return [
        leaf
        for prem in con2prems[key]
        for leaf in join_prems(prem, con2prems, expanded)
    ]
def shorten_proof(
    log: list[tuple[list[problem.Dependency], list[problem.Dependency]]],
    merge_trivials: bool = False,
) -> tuple[
    list[tuple[list[problem.Dependency], list[problem.Dependency]]],
    dict[tuple[str, ...], list[problem.Dependency]],
]:
    """Fold steps with an empty rule name into the steps that use them.

    A step is a folding candidate when its conclusion has `rule_name == ''`.
    Unless `merge_trivials` is set, candidates that serve as a premise of a
    step with a non-empty rule name are kept as steps of their own. Remaining
    candidates are removed from the log (except the final entry, which is
    always kept) and substituted by their own premises wherever they occur.

    Args:
      log: proof steps; each entry must hold exactly one conclusion.
      merge_trivials: if True, fold candidates even when a step with a
        non-empty rule name uses them.

    Returns:
      `(log, con2prem)`: the shortened log, with premises de-duplicated by
      hash per step, and the mapping from each folded conclusion hash to its
      premises.
    """
    trivial_prems = {}
    protected = set()
    for prems, cons in log:
        assert len(cons) == 1
        head = cons[0]
        if head.rule_name == '':  # pylint: disable=g-explicit-bool-comparison
            trivial_prems[head.hashed()] = prems
        elif not merge_trivials:
            # Premises of non-trivial steps must survive as their own steps.
            protected.update({p.hashed() for p in prems})

    for key in protected:
        trivial_prems.pop(key, None)

    expanded = set()
    shortened = []
    last_index = len(log) - 1
    for index, (prems, cons) in enumerate(log):
        head = cons[0]
        head_key = head.hashed()
        # Folded steps disappear, but the final entry is always emitted.
        if index < last_index and head_key in trivial_prems:
            continue

        seen = set()
        merged = []
        for prem in prems:
            for leaf in join_prems(prem, trivial_prems, expanded):
                leaf_key = leaf.hashed()
                if leaf_key not in seen:
                    seen.add(leaf_key)
                    merged.append(leaf)

        shortened.append((merged, [head]))
        expanded.add(head_key)

    return shortened, trivial_prems