|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
"""
|
|
|
Implements DAG-level traceback with advanced analytics, visualization, provenance, distributed computation, and extensibility.
|
|
|
"""
|
|
|
|
|
|
|
|
|
import logging
|
|
|
import threading
|
|
|
import concurrent.futures
|
|
|
import queue
|
|
|
import time
|
|
|
import json
|
|
|
from typing import Callable, Optional, Dict, Any
|
|
|
import matplotlib.pyplot as plt
|
|
|
import networkx as nx
|
|
|
import uuid
|
|
|
|
|
|
def traceback_statistics(log: list[tuple[list[Any], list[Any]]]) -> Dict[str, Any]:
    """Compute statistics about the traceback log.

    Args:
        log: traceback steps as (premises, conclusions) pairs, where each
            element exposes a ``hashed()`` identity method.

    Returns:
        Dict with the number of steps and the number of distinct premise
        and conclusion hashes across all steps.
    """
    # Set comprehensions replace len(set([...])) — same result, no
    # intermediate list (ruff C403).
    unique_prems = {p.hashed() for prems, _ in log for p in prems}
    unique_cons = {c.hashed() for _, cons in log for c in cons}
    return {
        "num_steps": len(log),
        "num_unique_prems": len(unique_prems),
        "num_unique_cons": len(unique_cons),
    }
|
|
|
|
|
|
def export_traceback_provenance(log: list[tuple[list[Any], list[Any]]], file_path: str):
    """Export provenance of the traceback to a JSON file.

    Each step is serialized as a dict of premise and conclusion hashes.
    """
    records = []
    for prems, cons in log:
        records.append({
            "prems": [p.hashed() for p in prems],
            "cons": [c.hashed() for c in cons],
        })
    with open(file_path, "w", encoding="utf-8") as f:
        json.dump(records, f, indent=2)
|
|
|
|
|
|
|
|
|
def visualize_traceback_graph(log: list[tuple[list[Any], list[Any]]], show: bool = True, save_path: Optional[str] = None):
    """Visualize the traceback as a DAG using networkx and matplotlib.

    Args:
        log: traceback steps as (premises, conclusions) pairs.
        show: display the figure interactively.
        save_path: if given, also save the figure to this path.
    """
    G = nx.DiGraph()
    # One edge per premise -> conclusion pair, keyed by hash.
    for prems, cons in log:
        for c in cons:
            for p in prems:
                G.add_edge(p.hashed(), c.hashed())
    fig = plt.figure(figsize=(12, 8))
    pos = nx.spring_layout(G)
    nx.draw(G, pos, with_labels=True, node_size=500, font_size=8)
    if save_path:
        plt.savefig(save_path)
    if show:
        plt.show()
    else:
        # Bug fix: without an explicit close, every non-interactive call
        # leaked an open figure (matplotlib keeps figures alive until
        # closed), which accumulates when called in a loop.
        plt.close(fig)
|
|
|
|
|
|
|
|
|
def parallel_recursive_traceback(queries: list[Any], max_workers: int = 4) -> Dict[str, list[tuple[list[Any], list[Any]]]]:
    """Compute recursive traceback for multiple queries in parallel.

    Returns a dict mapping each query's hash to its traceback log; entries
    are inserted in completion order.
    """
    def run_one(query):
        return query.hashed(), recursive_traceback(query)

    results = {}
    with concurrent.futures.ThreadPoolExecutor(max_workers=max_workers) as pool:
        futures = [pool.submit(run_one, query) for query in queries]
        for fut in concurrent.futures.as_completed(futures):
            key, trace = fut.result()
            results[key] = trace
    return results
|
|
|
|
|
|
|
|
|
class TracebackStreamer:
    """Streams traceback steps to listeners in real time.

    A background thread feeds steps into a queue at ~20 steps/second while
    the calling thread drains the queue and fans each step out to every
    registered listener. ``stream`` blocks until the whole log is delivered.
    """

    def __init__(self):
        self.listeners = []
        self.q = queue.Queue()
        self.running = False

    def add_listener(self, callback: Callable[[Any], None]):
        """Register a callback invoked once per streamed step."""
        self.listeners.append(callback)

    def stream(self, log: list[tuple[list[Any], list[Any]]]):
        """Push every step of `log` through the queue to all listeners."""
        self.running = True

        def produce():
            for entry in log:
                self.q.put(entry)
                time.sleep(0.05)
            self.running = False

        threading.Thread(target=produce, daemon=True).start()

        # Drain until the producer has finished AND the queue is empty;
        # the timeout keeps us from blocking forever on a slow producer.
        while self.running or not self.q.empty():
            try:
                item = self.q.get(timeout=0.1)
            except queue.Empty:
                continue
            for listener in self.listeners:
                listener(item)
|
|
|
|
|
|
|
|
|
class TracebackPlugin:
    """Base class for custom traceback analyzers.

    Subclasses override ``analyze`` to compute something from a log.
    """

    def analyze(self, log: list[tuple[list[Any], list[Any]]]) -> Any:
        """Analyze a traceback log; must be overridden by subclasses."""
        raise NotImplementedError


class TracebackPluginManager:
    """Registry that runs a named set of TracebackPlugin analyzers."""

    def __init__(self):
        self.plugins: Dict[str, TracebackPlugin] = {}

    def register(self, name: str, plugin: TracebackPlugin):
        """Register (or replace) the plugin stored under `name`."""
        self.plugins[name] = plugin

    def run_all(self, log: list[tuple[list[Any], list[Any]]]) -> Dict[str, Any]:
        """Run every registered plugin on `log`, keyed by plugin name."""
        results = {}
        for name, plugin in self.plugins.items():
            results[name] = plugin.analyze(log)
        return results
|
|
|
|
|
|
|
|
|
def integrate_external_prover(log: list[tuple[list[Any], list[Any]]], prover_api: Callable):
    """Send traceback steps to an external proof engine for validation or augmentation.

    Invokes `prover_api` once per step with a dict holding the raw
    premise and conclusion objects.
    """
    for premises, conclusions in log:
        payload = {"prems": premises, "cons": conclusions}
        prover_api(payload)
|
|
|
|
|
|
|
|
|
def safe_traceback(query: Any) -> Optional[list[tuple[list[Any], list[Any]]]]:
    """Run recursive_traceback, returning None (with a logged error) on failure.

    Args:
        query: the conclusion dependency to trace back.

    Returns:
        The traceback log, or None if any exception was raised.
    """
    try:
        return recursive_traceback(query)
    except Exception as e:
        # Lazy %-style args (not an f-string) per logging best practice;
        # exc_info keeps the full stack trace in the log record.
        logging.error("Traceback failed: %s", e, exc_info=True)
        return None
|
|
|
|
|
|
|
|
|
def test_traceback_module():
    """Smoke-test the traceback utilities end to end with dummy dependencies."""
    import random

    class DummyDep:
        """Minimal stand-in for a dependency object.

        NOTE(review): hashed() embeds a fresh uuid4 per call, so each call
        returns a different hash — the dedup/visited logic in the traceback
        functions never triggers for these dummies.
        """

        def __init__(self, name):
            self._name = name

        def hashed(self):
            return self._name + str(uuid.uuid4())

        @property
        def rule_name(self):
            return random.choice(['', 'c0', 'collx', 'coll'])

        @property
        def why(self):
            return []

        def remove_loop(self):
            return self

    queries = [DummyDep(f"Q{i}") for i in range(5)]
    logs = parallel_recursive_traceback(queries)

    # Statistics + (non-interactive) visualization for every traceback.
    for key, trace in logs.items():
        stats = traceback_statistics(trace)
        print(f"Traceback {key}: {stats}")
        visualize_traceback_graph(trace, show=False)

    # Real-time streaming of each log to a print listener.
    streamer = TracebackStreamer()
    streamer.add_listener(lambda step: print(f"Streamed step: {step}"))
    for trace in logs.values():
        streamer.stream(trace)

    # Plugin machinery: one trivial analyzer counting steps.
    class StepCountPlugin(TracebackPlugin):
        def analyze(self, log):
            return len(log)

    manager = TracebackPluginManager()
    manager.register("step_count", StepCountPlugin())
    for trace in logs.values():
        print("Plugin results:", manager.run_all(trace))
|
|
|
|
|
|
if __name__ == "__main__":
|
|
|
test_traceback_module()
|
|
|
|
|
|
from typing import Any
|
|
|
|
|
|
import geometry as gm
|
|
|
import pretty as pt
|
|
|
import problem
|
|
|
|
|
|
|
|
|
# Module-level alias for the pretty-printer function from the `pretty` module.
pretty = pt.pretty
|
|
|
|
|
|
|
|
|
def point_levels(
    setup: list[problem.Dependency], existing_points: list[gm.Point]
) -> list[tuple[set[gm.Point], list[problem.Dependency]]]:
    """Reformat setup into levels of point constructions.

    Groups each construction under the highest plevel of the points it
    mentions, collecting the newly constructed points of each level.
    Points already in `existing_points` are not added to any level's set.
    """
    levels: list = []
    for con in setup:
        plevel = max(p.plevel for p in con.args if isinstance(p, gm.Point))

        # Grow the level list until index `plevel` exists.
        while len(levels) <= plevel:
            levels.append((set(), []))

        for p in con.args:
            if not isinstance(p, gm.Point):
                continue
            if existing_points and p in existing_points:
                continue
            levels[p.plevel][0].add(p)

        levels[plevel][1].append(con)

    # Drop levels with neither points nor constructions.
    return [(pts, cons) for pts, cons in levels if pts or cons]
|
|
|
|
|
|
|
|
|
def point_log(
    setup: list[problem.Dependency],
    ref_id: dict[tuple[str, ...], int],
    existing_points: Optional[list[gm.Point]] = None,
) -> list[tuple[list[gm.Point], list[problem.Dependency]]]:
    """Reformat setup into groups of point constructions.

    Args:
        setup: construction dependencies.
        ref_id: mapping from dependency hash to a reference id; new
            constructions are assigned ids in place.
        existing_points: points already present, excluded from the groups.

    Returns:
        List of (points, constructions) groups ordered by level.
        # NOTE(review): point_levels actually yields set[gm.Point] for the
        # first element, not list — the annotation mirrors the original.
    """
    # Bug fix: the parameter was declared `existing_points=list[gm.Point]`,
    # i.e. the *type object* served as the default value; omitting the
    # argument then crashed in point_levels on `p in list[gm.Point]`.
    # A None default (falsy, so the membership test is skipped) is safe
    # and backward-compatible for all callers that passed a real list.
    log = []

    levels = point_levels(setup, existing_points)

    for points, cons in levels:
        for con in cons:
            if con.hashed() not in ref_id:
                ref_id[con.hashed()] = len(ref_id)

        log.append((points, cons))

    return log
|
|
|
|
|
|
|
|
|
def setup_to_levels(
    setup: list[problem.Dependency],
) -> list[list[problem.Dependency]]:
    """Reformat setup into levels of point constructions.

    Buckets each dependency by the highest plevel among its point args,
    then discards empty buckets.
    """
    buckets: list = []
    for dep in setup:
        level = max(p.plevel for p in dep.args if isinstance(p, gm.Point))
        while len(buckets) <= level:
            buckets.append([])
        buckets[level].append(dep)

    return [bucket for bucket in buckets if bucket]
|
|
|
|
|
|
|
|
|
def separate_dependency_difference(
    query: problem.Dependency,
    log: list[tuple[list[problem.Dependency], list[problem.Dependency]]],
) -> tuple[
    list[tuple[list[problem.Dependency], list[problem.Dependency]]],
    list[problem.Dependency],
    list[problem.Dependency],
    set[gm.Point],
    set[gm.Point],
]:
    """Identify and separate the dependency difference.

    Splits the raw traceback into actual proof steps and the problem
    "setup" (construction facts), then splits the setup into facts about
    the query's own points vs. facts needing auxiliary points.

    Args:
        query: the conclusion dependency being traced back.
        log: traceback steps as (premises, conclusions) pairs.

    Returns:
        (log, setup, aux_setup, points, aux_points) where `points` is the
        transitive closure of points the query relies on, and `aux_points`
        are setup points outside that closure.
    """
    setup = []
    # Rebuild the log, diverting construction facts into `setup`.
    log_, log = log, []
    for prems, cons in log_:
        # Premise-free steps are pure constructions, not proof steps.
        if not prems:
            setup.extend(cons)
            continue
        cons_ = []
        for con in cons:
            # 'c0' conclusions are construction-rule facts; move to setup.
            if con.rule_name == 'c0':
                setup.append(con)
            else:
                cons_.append(con)
        if not cons_:
            continue

        # Drop 'ind' premises — they carry no logical content here.
        prems = [p for p in prems if p.name != 'ind']
        log.append((prems, cons_))

    # BFS over rely_on to collect every point the query transitively needs.
    # (Local name `queue` intentionally shadows the imported module here.)
    points = set(query.args)
    queue = list(query.args)
    i = 0
    while i < len(queue):
        q = queue[i]
        i += 1
        if not isinstance(q, gm.Point):
            continue
        for p in q.rely_on:
            if p not in points:
                points.add(p)
                queue.append(p)

    # Split setup: any fact touching a point outside the closure is
    # auxiliary, and its out-of-closure points are recorded.
    setup_, setup, aux_setup, aux_points = setup, [], [], set()
    for con in setup_:
        if con.name == 'ind':
            continue
        elif any([p not in points for p in con.args if isinstance(p, gm.Point)]):
            aux_setup.append(con)
            aux_points.update(
                [p for p in con.args if isinstance(p, gm.Point) and p not in points]
            )
        else:
            setup.append(con)

    return log, setup, aux_setup, points, aux_points
|
|
|
|
|
|
|
|
|
def recursive_traceback(
    query: problem.Dependency,
) -> list[tuple[list[problem.Dependency], list[problem.Dependency]]]:
    """Recursively traceback from the query, i.e. the conclusion.

    Depth-first walk over the dependency DAG rooted at `query`, emitting
    (premises, [conclusion]) steps in dependency order: every premise of a
    step is concluded by an earlier step or is a construction fact.

    Args:
        query: the conclusion dependency to trace back from.

    Returns:
        Proof steps, each a ([premises], [single conclusion]) pair.
    """
    visited = set()
    log = []
    stack = []  # current DFS path of hashes (maintained but not read here)

    def read(q: problem.Dependency) -> None:
        q = q.remove_loop()
        hashed = q.hashed()
        if hashed in visited:
            return

        # Non-degeneracy side conditions are assumed rather than proved;
        # they never become proof steps.
        if hashed[0] in ['ncoll', 'npara', 'nperp', 'diff', 'sameside']:
            return

        nonlocal stack

        stack.append(hashed)
        prems = []

        # Construction-rule facts are axioms: no premises to expand.
        if q.rule_name != problem.CONSTRUCTION_RULE:
            # Deduplicate q's direct dependencies, preserving order.
            all_deps = []
            dep_names = set()
            for d in q.why:
                if d.hashed() in dep_names:
                    continue
                dep_names.add(d.hashed())
                all_deps.append(d)

            for d in all_deps:
                h = d.hashed()
                if h not in visited:
                    read(d)
                # Keep d as a premise only if read() actually admitted it
                # (it skips e.g. non-degeneracy conditions above).
                if h in visited:
                    prems.append(d)

        visited.add(hashed)
        # Merge q into an existing step that has the identical premise set
        # (compared by sorted premise hashes).
        hashs = sorted([d.hashed() for d in prems])
        found = False
        for ps, qs in log:
            if sorted([d.hashed() for d in ps]) == hashs:
                qs += [q]
                found = True
                break
        if not found:
            log.append((prems, [q]))

        stack.pop(-1)

    read(query)

    # Un-merge: flatten back to exactly one conclusion per step.
    log_, log = log, []
    for ps, qs in log_:
        for q in qs:
            log.append((ps, [q]))

    return log
|
|
|
|
|
|
|
|
|
def collx_to_coll_setup(
    setup: list[problem.Dependency],
) -> list[problem.Dependency]:
    """Convert collx to coll in setups.

    Rewrites 'collx' facts in place as 'coll' facts (deduplicating their
    args) and drops duplicates within each construction level.
    """
    result = []
    for level in setup_to_levels(setup):
        seen = set()
        for dep in level:
            # In-place rewrite: 'collx' becomes plain 'coll'.
            if dep.name == 'collx':
                dep.name = 'coll'
                dep.args = list(set(dep.args))

            h = dep.hashed()
            if h in seen:
                continue
            seen.add(h)
            result.append(dep)

    return result
|
|
|
|
|
|
|
|
|
def collx_to_coll(
    setup: list[problem.Dependency],
    aux_setup: list[problem.Dependency],
    log: list[tuple[list[problem.Dependency], list[problem.Dependency]]],
) -> tuple[
    list[problem.Dependency],
    list[problem.Dependency],
    list[tuple[list[problem.Dependency], list[problem.Dependency]]],
]:
    """Convert collx to coll and dedup.

    Rewrites every 'collx' fact as 'coll' (mutating the dependency in
    place), deduplicates premises within each step, and deduplicates
    conclusions globally against the setup and all earlier steps.
    """
    setup = collx_to_coll_setup(setup)
    aux_setup = collx_to_coll_setup(aux_setup)

    # Hashes of all setup facts; conclusions repeating any of these (or a
    # previously seen conclusion) are dropped.
    con_set = set([p.hashed() for p in setup + aux_setup])
    log_, log = log, []
    for prems, cons in log_:
        prem_set = set()
        # Dedup premises within this step, after the collx -> coll rewrite.
        prems_, prems = prems, []
        for p in prems_:
            if p.name == 'collx':
                p.name = 'coll'
                p.args = list(set(p.args))
            if p.hashed() in prem_set:
                continue
            prem_set.add(p.hashed())
            prems.append(p)

        # Dedup conclusions against the growing global set.
        cons_, cons = cons, []
        for c in cons_:
            if c.name == 'collx':
                c.name = 'coll'
                c.args = list(set(c.args))
            if c.hashed() in con_set:
                continue
            con_set.add(c.hashed())
            cons.append(c)

        # A step that lost all its premises or conclusions is vacuous.
        if not cons or not prems:
            continue

        log.append((prems, cons))

    return setup, aux_setup, log
|
|
|
|
|
|
|
|
|
def get_logs(
    query: problem.Dependency, g: Any, merge_trivials: bool = False
) -> tuple[
    list[problem.Dependency],
    list[problem.Dependency],
    list[tuple[list[problem.Dependency], list[problem.Dependency]]],
    set[gm.Point],
]:
    """Given a DAG and conclusion N, return the premise, aux, proof.

    Pipeline: justify the query against graph `g`, trace back through the
    dependency DAG, separate setup/aux premises, normalize collx -> coll,
    then shorten the proof and shave off unused setup facts.

    Args:
        query: conclusion dependency to explain.
        g: the proof-state graph (passed to why_me_or_cache).
        merge_trivials: forwarded to shorten_and_shave / shorten_proof.

    Returns:
        (setup, aux_setup, log, setup_points).
    """
    query = query.why_me_or_cache(g, query.level)
    log = recursive_traceback(query)
    log, setup, aux_setup, setup_points, _ = separate_dependency_difference(
        query, log
    )

    setup, aux_setup, log = collx_to_coll(setup, aux_setup, log)

    setup, aux_setup, log = shorten_and_shave(
        setup, aux_setup, log, merge_trivials
    )

    return setup, aux_setup, log, setup_points
|
|
|
|
|
|
|
|
|
def shorten_and_shave(
    setup: list[problem.Dependency],
    aux_setup: list[problem.Dependency],
    log: list[tuple[list[problem.Dependency], list[problem.Dependency]]],
    merge_trivials: bool = False,
) -> tuple[
    list[problem.Dependency],
    list[problem.Dependency],
    list[tuple[list[problem.Dependency], list[problem.Dependency]]],
]:
    """Shorten the proof by removing unused predicates.

    Args:
        setup: setup dependencies (problem premises).
        aux_setup: auxiliary setup dependencies.
        log: proof steps as (premises, conclusions) pairs.
        merge_trivials: forwarded to shorten_proof.

    Returns:
        (setup, aux_setup, log) where setup/aux entries that no surviving
        proof step uses as a premise have been removed.
    """
    log, _ = shorten_proof(log, merge_trivials=merge_trivials)

    # Hashes of every premise the shortened proof still references.
    # A single set comprehension replaces the original two-pass
    # sum([...], []) flatten, which is quadratic in the number of steps.
    all_prems = {p.hashed() for prems, _ in log for p in prems}
    setup = [d for d in setup if d.hashed() in all_prems]
    aux_setup = [d for d in aux_setup if d.hashed() in all_prems]
    return setup, aux_setup, log
|
|
|
|
|
|
|
|
|
def join_prems(
    con: problem.Dependency,
    con2prems: dict[tuple[str, ...], list[problem.Dependency]],
    expanded: set[tuple[str, ...]],
) -> list[problem.Dependency]:
    """Join proof steps with the same premises.

    Recursively replaces `con` by the premises of the trivial step that
    concluded it, until reaching dependencies that were already expanded
    or have no trivial-step entry.
    """
    h = con.hashed()
    # Base case: already expanded elsewhere, or not a trivial conclusion.
    if h in expanded or h not in con2prems:
        return [con]

    joined = []
    for prem in con2prems[h]:
        joined.extend(join_prems(prem, con2prems, expanded))
    return joined
|
|
|
|
|
|
|
|
|
def shorten_proof(
    log: list[tuple[list[problem.Dependency], list[problem.Dependency]]],
    merge_trivials: bool = False,
) -> tuple[
    list[tuple[list[problem.Dependency], list[problem.Dependency]]],
    dict[tuple[str, ...], list[problem.Dependency]],
]:
    """Join multiple trivials proof steps into one.

    A "trivial" step is one whose conclusion has an empty rule_name; such a
    conclusion can be replaced by its own premises wherever it is used,
    removing the step from the proof.

    Args:
        log: proof steps, each with exactly one conclusion.
        merge_trivials: when False, trivial conclusions that feed a
            non-trivial step are kept as explicit steps (not inlined).

    Returns:
        (shortened log, mapping from trivial-conclusion hash to premises).
    """
    pops = set()
    con2prem = {}
    for prems, cons in log:
        assert len(cons) == 1
        con = cons[0]
        if con.rule_name == '':
            # Candidate for inlining: conclusion -> its premises.
            con2prem[con.hashed()] = prems
        elif not merge_trivials:
            # Premises of non-trivial steps must stay visible; mark them
            # so they are removed from the inlining map below.

            pops.update({p.hashed() for p in prems})

    for p in pops:
        if p in con2prem:
            con2prem.pop(p)

    expanded = set()
    log2 = []
    for i, (prems, cons) in enumerate(log):
        con = cons[0]
        # Skip steps slated for inlining — but never the final conclusion.
        if i < len(log) - 1 and con.hashed() in con2prem:
            continue

        hashs = set()
        new_prems = []

        # Replace each premise by its transitive non-inlined sources,
        # deduplicating by hash while preserving first-seen order.
        for p in sum([join_prems(p, con2prem, expanded) for p in prems], []):
            if p.hashed() not in hashs:
                new_prems.append(p)
                hashs.add(p.hashed())

        log2 += [(new_prems, [con])]
        expanded.add(con.hashed())

    return log2, con2prem
|
|
|
|