|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
"""Implements DAG-level traceback."""
|
|
|
|
|
|
from typing import Any
|
|
|
|
|
|
import geometry as gm
|
|
|
import pretty as pt
|
|
|
import problem
|
|
|
|
|
|
|
|
|
# Module-level alias so `pt.pretty` can be referred to simply as `pretty`.
pretty = pt.pretty
|
|
|
|
|
|
|
|
|
def point_levels(
    setup: list[problem.Dependency], existing_points: list[gm.Point]
) -> list[tuple[set[gm.Point], list[problem.Dependency]]]:
    """Bucket constructions, and the points they mention, by point level.

    A dependency is filed under the largest `plevel` found among its
    `gm.Point` arguments. Each of its point arguments that is not listed in
    `existing_points` is recorded under that point's own `plevel`.

    Args:
      setup: dependencies to distribute over levels.
      existing_points: points to leave out of the per-level point sets. An
        empty (falsy) value disables the filtering.

    Returns:
      One `(points, dependencies)` pair per level, in increasing level order,
      omitting levels where both the point set and the dependency list are
      empty.

    Raises:
      ValueError: if a dependency has no `gm.Point` argument, because `max()`
        is then applied to an empty list.
    """
    buckets = []
    for dep in setup:
        dep_points = [arg for arg in dep.args if isinstance(arg, gm.Point)]
        top = max([point.plevel for point in dep_points])

        # Grow the bucket list until index `top` exists.
        while len(buckets) <= top:
            buckets.append((set(), []))

        for point in dep_points:
            already_known = existing_points and point in existing_points
            if not already_known:
                buckets[point.plevel][0].add(point)

        buckets[top][1].append(dep)

    return [(points, deps) for points, deps in buckets if points or deps]
|
|
|
|
|
|
|
|
|
def point_log(
    setup: list[problem.Dependency],
    ref_id: dict[tuple[str, ...], int],
    existing_points: 'list[gm.Point] | None' = None,
) -> list[tuple[set[gm.Point], list[problem.Dependency]]]:
    """Group setup constructions by point level and number them in `ref_id`.

    Args:
      setup: dependencies to group; forwarded to `point_levels`.
      ref_id: mapping from a dependency's `hashed()` key to an integer id.
        It is updated in place: every construction not yet present receives
        the next free id (the current size of the mapping).
      existing_points: points to exclude from the per-level point sets.
        Defaults to None, in which case no point is excluded. (The parameter
        previously used `=` instead of `:`, which made the type alias object
        itself the default value rather than a collection of points.)

    Returns:
      The `(points, constructions)` pairs produced by `point_levels`, in the
      same order.
    """
    log = []
    for points, cons in point_levels(setup, existing_points):
        for con in cons:
            # len(ref_id) is evaluated before insertion, so ids stay dense
            # and already-registered constructions keep their id.
            ref_id.setdefault(con.hashed(), len(ref_id))
        log.append((points, cons))
    return log
|
|
|
|
|
|
|
|
|
def setup_to_levels(
    setup: list[problem.Dependency],
) -> list[list[problem.Dependency]]:
    """Split setup dependencies into lists ordered by point level.

    Each dependency is placed at the largest `plevel` found among its
    `gm.Point` arguments.

    Args:
      setup: dependencies to split.

    Returns:
      The non-empty per-level lists, in increasing level order.

    Raises:
      ValueError: if a dependency has no `gm.Point` argument, because `max()`
        is then applied to an empty sequence.
    """
    by_level = []
    for dep in setup:
        top = max(arg.plevel for arg in dep.args if isinstance(arg, gm.Point))
        # Make sure index `top` exists before appending to it.
        while len(by_level) <= top:
            by_level.append([])
        by_level[top].append(dep)

    return [group for group in by_level if group]
|
|
|
|
|
|
|
|
|
def separate_dependency_difference(
    query: problem.Dependency,
    log: list[tuple[list[problem.Dependency], list[problem.Dependency]]],
) -> tuple[
    list[tuple[list[problem.Dependency], list[problem.Dependency]]],
    list[problem.Dependency],
    list[problem.Dependency],
    set[gm.Point],
    set[gm.Point],
]:
    """Split a traceback log into proof steps, setup and auxiliary setup.

    Args:
      query: the conclusion; its arguments seed the set of relevant points.
      log: `(premises, conclusions)` steps to split.

    Returns:
      A 5-tuple `(log, setup, aux_setup, points, aux_points)`:
        log: steps that have premises, restricted to conclusions whose
          `rule_name` is not 'c0' and to premises whose `name` is not 'ind'.
        setup: collected setup dependencies whose point arguments all lie in
          `points`.
        aux_setup: collected setup dependencies mentioning at least one point
          outside `points`.
        points: the query arguments plus everything reachable from the
          `gm.Point` ones through `rely_on`.
        aux_points: the points of `aux_setup` that are not in `points`.
    """
    # Pass 1: premise-free steps and 'c0' conclusions are setup material;
    # everything else remains a proof step.
    collected = []
    proof_steps = []
    for premises, conclusions in log:
        if not premises:
            collected.extend(conclusions)
            continue

        derived = []
        for conclusion in conclusions:
            target = collected if conclusion.rule_name == 'c0' else derived
            target.append(conclusion)

        if derived:
            kept_premises = [p for p in premises if p.name != 'ind']
            proof_steps.append((kept_premises, derived))

    # Pass 2: breadth-first closure of the query arguments over `rely_on`.
    points = set(query.args)
    frontier = list(query.args)
    cursor = 0
    while cursor < len(frontier):
        node = frontier[cursor]
        cursor += 1
        if isinstance(node, gm.Point):
            for parent in node.rely_on:
                if parent in points:
                    continue
                points.add(parent)
                frontier.append(parent)

    # Pass 3: setup dependencies touching a point outside the closure are
    # auxiliary; 'ind' dependencies are dropped altogether.
    setup, aux_setup, aux_points = [], [], set()
    for con in collected:
        if con.name == 'ind':
            continue
        outside = [
            p for p in con.args if isinstance(p, gm.Point) and p not in points
        ]
        if outside:
            aux_setup.append(con)
            aux_points.update(outside)
        else:
            setup.append(con)

    return proof_steps, setup, aux_setup, points, aux_points
|
|
|
|
|
|
|
|
|
def recursive_traceback(
    query: problem.Dependency,
) -> list[tuple[list[problem.Dependency], list[problem.Dependency]]]:
    """Recursively trace back from the query, i.e. the conclusion.

    Dependencies are visited depth-first through `why`, so every recorded
    premise is recorded before the conclusion that uses it. Conclusions that
    share the same multiset of premise hashes are emitted next to each other,
    at the position where that premise set was first seen.

    Args:
      query: the dependency to trace back from.

    Returns:
      A list of `(premises, [conclusion])` steps, one conclusion per step.
      Dependencies whose hashed name is one of 'ncoll', 'npara', 'nperp',
      'diff' or 'sameside' are neither recorded as steps nor listed as
      premises.
    """
    visited = set()
    # Sorted premise hashes -> (premises, conclusions). A dict keeps insertion
    # order, so groups come out in first-seen order, exactly as a linear scan
    # over a list of groups would produce, but without re-sorting and
    # re-comparing every existing group for each new conclusion.
    groups = {}
    skipped_names = ('ncoll', 'npara', 'nperp', 'diff', 'sameside')

    def read(q: problem.Dependency) -> None:
        q = q.remove_loop()
        hashed = q.hashed()
        if hashed in visited:
            return

        if hashed[0] in skipped_names:
            return

        prems = []

        if q.rule_name != problem.CONSTRUCTION_RULE:
            # Deduplicate the direct dependencies first, keeping their order.
            all_deps = []
            dep_names = set()
            for d in q.why:
                if d.hashed() in dep_names:
                    continue
                dep_names.add(d.hashed())
                all_deps.append(d)

            for d in all_deps:
                h = d.hashed()
                if h not in visited:
                    read(d)
                # Only dependencies that actually got recorded count as
                # premises; skipped ones never enter `visited`.
                if h in visited:
                    prems.append(d)

        visited.add(hashed)

        signature = tuple(sorted([d.hashed() for d in prems]))
        group = groups.get(signature)
        if group is None:
            groups[signature] = (prems, [q])
        else:
            group[1].append(q)

    read(query)

    # Flatten the groups: one step per conclusion.
    return [(ps, [q]) for ps, qs in groups.values() for q in qs]
|
|
|
|
|
|
|
|
|
def collx_to_coll_setup(
    setup: list[problem.Dependency],
) -> list[problem.Dependency]:
    """Rewrite 'collx' setup dependencies as 'coll' and drop duplicates.

    The dependencies are processed level by level (see `setup_to_levels`).
    A 'collx' dependency is modified in place: its name becomes 'coll' and
    its arguments are replaced by a list of its distinct arguments.
    Duplicates, judged by `hashed()`, are removed within each level only.

    Args:
      setup: setup dependencies; 'collx' entries are mutated.

    Returns:
      The surviving dependencies, level after level.
    """
    converted = []
    for level in setup_to_levels(setup):
        seen = set()
        for dep in level:
            if dep.name == 'collx':
                dep.name = 'coll'
                dep.args = list(set(dep.args))

            key = dep.hashed()
            if key not in seen:
                seen.add(key)
                converted.append(dep)

    return converted
|
|
|
|
|
|
|
|
|
def collx_to_coll(
    setup: list[problem.Dependency],
    aux_setup: list[problem.Dependency],
    log: list[tuple[list[problem.Dependency], list[problem.Dependency]]],
) -> tuple[
    list[problem.Dependency],
    list[problem.Dependency],
    list[tuple[list[problem.Dependency], list[problem.Dependency]]],
]:
    """Rewrite 'collx' as 'coll' everywhere and remove duplicates.

    Args:
      setup: setup dependencies, converted via `collx_to_coll_setup`.
      aux_setup: auxiliary setup dependencies, converted the same way.
      log: `(premises, conclusions)` proof steps. 'collx' dependencies in
        them are mutated in place.

    Returns:
      `(setup, aux_setup, log)` after conversion. Within a step, repeated
      premises are dropped. A conclusion is dropped when its hash already
      occurs in the setup, the auxiliary setup, or an earlier conclusion.
      Steps left without premises or without conclusions are removed.
    """

    def as_coll(dep):
        # Convert in place when needed and hand back the resulting hash.
        if dep.name == 'collx':
            dep.name = 'coll'
            dep.args = list(set(dep.args))
        return dep.hashed()

    setup = collx_to_coll_setup(setup)
    aux_setup = collx_to_coll_setup(aux_setup)

    known_cons = {dep.hashed() for dep in setup + aux_setup}
    cleaned = []
    for premises, conclusions in log:
        seen_prems = set()
        kept_prems = []
        for prem in premises:
            key = as_coll(prem)
            if key not in seen_prems:
                seen_prems.add(key)
                kept_prems.append(prem)

        kept_cons = []
        for con in conclusions:
            key = as_coll(con)
            if key not in known_cons:
                # Registered even if the step ends up discarded below.
                known_cons.add(key)
                kept_cons.append(con)

        if kept_cons and kept_prems:
            cleaned.append((kept_prems, kept_cons))

    return setup, aux_setup, cleaned
|
|
|
|
|
|
|
|
|
def get_logs(
    query: problem.Dependency, g: Any, merge_trivials: bool = False
) -> tuple[
    list[problem.Dependency],
    list[problem.Dependency],
    list[tuple[list[problem.Dependency], list[problem.Dependency]]],
    set[gm.Point],
]:
    """Given a DAG and a conclusion, return the premises, aux and proof.

    Runs the full pipeline: trace back from the query, separate setup from
    proof steps, convert 'collx' to 'coll', then shorten the proof.

    Args:
      query: the conclusion to explain; first resolved through
        `query.why_me_or_cache(g, query.level)`.
      g: object passed to `why_me_or_cache`.
      merge_trivials: forwarded to `shorten_and_shave`.

    Returns:
      `(setup, aux_setup, log, setup_points)`, where `setup_points` is the
      point set computed by `separate_dependency_difference`.
    """
    query = query.why_me_or_cache(g, query.level)
    traced = recursive_traceback(query)

    # The fifth element (auxiliary points) is not needed here.
    log, setup, aux_setup, setup_points, _aux_points = (
        separate_dependency_difference(query, traced)
    )

    converted = collx_to_coll(setup, aux_setup, log)
    setup, aux_setup, log = shorten_and_shave(*converted, merge_trivials)

    return setup, aux_setup, log, setup_points
|
|
|
|
|
|
|
|
|
def shorten_and_shave(
    setup: list[problem.Dependency],
    aux_setup: list[problem.Dependency],
    log: list[tuple[list[problem.Dependency], list[problem.Dependency]]],
    merge_trivials: bool = False,
) -> tuple[
    list[problem.Dependency],
    list[problem.Dependency],
    list[tuple[list[problem.Dependency], list[problem.Dependency]]],
]:
    """Shorten the proof, then discard setup entries it no longer uses.

    Args:
      setup: setup dependencies.
      aux_setup: auxiliary setup dependencies.
      log: `(premises, conclusions)` proof steps, passed to `shorten_proof`.
      merge_trivials: forwarded to `shorten_proof`.

    Returns:
      `(setup, aux_setup, log)` where `log` is the shortened proof and both
      setup lists keep only the dependencies whose hash appears among the
      premises of the shortened proof.
    """
    log, _ = shorten_proof(log, merge_trivials=merge_trivials)

    used = {prem.hashed() for prems, _cons in log for prem in prems}

    def is_used(dep):
        return dep.hashed() in used

    return list(filter(is_used, setup)), list(filter(is_used, aux_setup)), log
|
|
|
|
|
|
|
|
|
def join_prems(
    con: problem.Dependency,
    con2prems: dict[tuple[str, ...], list[problem.Dependency]],
    expanded: set[tuple[str, ...]],
) -> list[problem.Dependency]:
    """Replace a dependency by the premises it can be expanded into.

    Args:
      con: the dependency to expand.
      con2prems: maps the hash of an expandable dependency to its premises.
      expanded: hashes that must not be expanded.

    Returns:
      `[con]` when its hash is in `expanded` or missing from `con2prems`;
      otherwise the concatenation of the recursive expansion of each of its
      premises, in order (duplicates are not removed).
    """
    key = con.hashed()
    if key not in expanded and key in con2prems:
        return [
            leaf
            for prem in con2prems[key]
            for leaf in join_prems(prem, con2prems, expanded)
        ]
    return [con]
|
|
|
|
|
|
|
|
|
def shorten_proof(
    log: list[tuple[list[problem.Dependency], list[problem.Dependency]]],
    merge_trivials: bool = False,
) -> tuple[
    list[tuple[list[problem.Dependency], list[problem.Dependency]]],
    dict[tuple[str, ...], list[problem.Dependency]],
]:
    """Fold trivial proof steps into the steps that use them.

    A step is trivial when its conclusion has an empty `rule_name`. Unless
    `merge_trivials` is set, a trivial conclusion that serves as a premise of
    a non-trivial step is kept as its own step.

    Args:
      log: proof steps; each must carry exactly one conclusion.
      merge_trivials: when True, trivial steps are folded even if a
        non-trivial step uses them as a premise.

    Returns:
      `(log, con2prem)`: the shortened steps, with premises expanded through
      `join_prems` and deduplicated by hash, and the mapping from the hash of
      each folded conclusion to its premises. The final step of `log` is
      always kept.
    """
    protected = set()
    foldable = {}
    for prems, cons in log:
        assert len(cons) == 1
        con = cons[0]
        if con.rule_name == '':
            foldable[con.hashed()] = prems
        elif not merge_trivials:
            # Premises of a non-trivial step must remain visible steps.
            protected.update({p.hashed() for p in prems})

    for key in protected:
        foldable.pop(key, None)

    expanded = set()
    shortened = []
    last = len(log) - 1
    for idx, (prems, cons) in enumerate(log):
        con = cons[0]
        if idx < last and con.hashed() in foldable:
            continue

        leaves = [
            leaf for prem in prems for leaf in join_prems(prem, foldable, expanded)
        ]

        seen = set()
        merged = []
        for leaf in leaves:
            if leaf.hashed() not in seen:
                merged.append(leaf)
                seen.add(leaf.hashed())

        shortened.append((merged, [con]))
        expanded.add(con.hashed())

    return shortened, foldable
|
|
|
|