# Axiovora-X / backend/core/trace_back.py
# ZAIDX11's picture
# Add files using upload-large-folder tool
# effde1c verified
# Copyright 2023 DeepMind Technologies Limited
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""
Implements DAG-level traceback with advanced analytics, visualization, provenance, distributed computation, and extensibility.
"""
# --- Advanced Traceback Expansion Imports ---
import logging
import threading
import concurrent.futures
import queue
import time
import json
from typing import Callable, Optional, Dict, Any
import matplotlib.pyplot as plt
import networkx as nx
import uuid
# --- Advanced Traceback Analytics ---
def traceback_statistics(log: list[tuple[list[Any], list[Any]]]) -> Dict[str, Any]:
    """Summarise a traceback log.

    Args:
        log: Sequence of ``(premises, conclusions)`` steps. Every premise and
            conclusion must provide a ``hashed()`` method returning a
            hashable value.

    Returns:
        A dict with the number of steps (``num_steps``) and the number of
        distinct premise / conclusion hashes (``num_unique_prems`` /
        ``num_unique_cons``).
    """
    premise_hashes = set()
    conclusion_hashes = set()
    for premises, conclusions in log:
        premise_hashes.update(dep.hashed() for dep in premises)
        conclusion_hashes.update(dep.hashed() for dep in conclusions)

    return {
        "num_steps": len(log),
        "num_unique_prems": len(premise_hashes),
        "num_unique_cons": len(conclusion_hashes),
    }
def export_traceback_provenance(log: list[tuple[list[Any], list[Any]]], file_path: str):
    """Write the hashes of every traceback step to a JSON file.

    Args:
        log: Sequence of ``(premises, conclusions)`` steps whose items expose
            ``hashed()``; the returned hashes must be JSON-serialisable.
        file_path: Destination path; the file is (over)written as UTF-8.
    """
    records = []
    for premises, conclusions in log:
        records.append(
            {
                "prems": [dep.hashed() for dep in premises],
                "cons": [dep.hashed() for dep in conclusions],
            }
        )

    with open(file_path, "w", encoding="utf-8") as out:
        json.dump(records, out, indent=2)
# --- Visualization Utilities ---
def visualize_traceback_graph(log: list[tuple[list[Any], list[Any]]], show: bool = True, save_path: Optional[str] = None):
    """Draw the traceback as a directed graph with networkx and matplotlib.

    Every step contributes one edge per (premise, conclusion) pair, running
    from the premise hash to the conclusion hash.

    Args:
        log: Sequence of ``(premises, conclusions)`` steps whose items expose
            ``hashed()``.
        show: When true, call ``plt.show()`` after drawing.
        save_path: When truthy, save the figure to this path.
    """
    graph = nx.DiGraph()
    graph.add_edges_from(
        (premise.hashed(), conclusion.hashed())
        for premises, conclusions in log
        for conclusion in conclusions
        for premise in premises
    )

    # NOTE(review): the figure created here is never closed in this function;
    # callers producing many graphs may want to call plt.close() themselves.
    plt.figure(figsize=(12, 8))
    layout = nx.spring_layout(graph)
    nx.draw(graph, layout, with_labels=True, node_size=500, font_size=8)

    if save_path:
        plt.savefig(save_path)
    if show:
        plt.show()
# --- Parallel Traceback Computation (local thread pool) ---
def parallel_recursive_traceback(queries: list[Any], max_workers: int = 4) -> Dict[str, list[tuple[list[Any], list[Any]]]]:
    """Run ``recursive_traceback`` for several queries on a thread pool.

    Args:
        queries: Dependencies to trace back; each must expose ``hashed()``.
        max_workers: Size of the thread pool.

    Returns:
        A dict mapping each query's ``hashed()`` value to its traceback log.
        Entries are inserted in completion order.
    """

    def _trace(query):
        return query.hashed(), recursive_traceback(query)

    results = {}
    with concurrent.futures.ThreadPoolExecutor(max_workers=max_workers) as pool:
        pending = [pool.submit(_trace, query) for query in queries]
        for finished in concurrent.futures.as_completed(pending):
            key, trace = finished.result()
            results[key] = trace
    return results
# --- Real-Time Traceback Streaming (background thread + queue) ---
class TracebackStreamer:
    """Feeds traceback steps to registered callbacks one at a time.

    A background thread enqueues the steps (pausing 0.05 s after each one)
    while the calling thread dequeues them and invokes every listener.
    """

    def __init__(self):
        self.listeners = []
        self.q = queue.Queue()
        self.running = False

    def add_listener(self, callback: Callable[[Any], None]):
        """Register ``callback`` to be called with every streamed step."""
        self.listeners.append(callback)

    def stream(self, log: list[tuple[list[Any], list[Any]]]):
        """Stream ``log`` to the listeners; blocks until all steps are sent."""

        def _feed():
            for entry in log:
                self.q.put(entry)
                time.sleep(0.05)
            # Cleared only after the final put, so the consumer below cannot
            # stop while steps are still outstanding.
            self.running = False

        self.running = True
        producer = threading.Thread(target=_feed, daemon=True)
        producer.start()

        while True:
            if not self.running and self.q.empty():
                break
            try:
                entry = self.q.get(timeout=0.1)
                for listener in self.listeners:
                    listener(entry)
            except queue.Empty:
                continue
# --- Plugin System for Custom Trace Analyzers ---
class TracebackPlugin:
    """Interface for pluggable traceback analyzers.

    Subclasses override :meth:`analyze`; the base implementation only raises.
    """

    def analyze(self, log: list[tuple[list[Any], list[Any]]]) -> Any:
        """Analyze ``log`` and return a result.

        Raises:
            NotImplementedError: Always, in this base class.
        """
        raise NotImplementedError
class TracebackPluginManager:
    """Registry of named analyzers that can all be run over one log."""

    def __init__(self):
        self.plugins: Dict[str, TracebackPlugin] = {}

    def register(self, name: str, plugin: TracebackPlugin):
        """Store ``plugin`` under ``name``, replacing any previous entry."""
        self.plugins[name] = plugin

    def run_all(self, log: list[tuple[list[Any], list[Any]]]) -> Dict[str, Any]:
        """Return ``{name: plugin.analyze(log)}`` for every registered plugin."""
        outcomes = {}
        for name, plugin in self.plugins.items():
            outcomes[name] = plugin.analyze(log)
        return outcomes
# --- External Proof Engine Integration (Stub) ---
def integrate_external_prover(log: list[tuple[list[Any], list[Any]]], prover_api: Callable):
    """Forward every traceback step to an external prover callback.

    Args:
        log: Sequence of ``(premises, conclusions)`` steps.
        prover_api: Callable invoked once per step, in order, with a dict
            holding the step's premises under ``"prems"`` and its conclusions
            under ``"cons"``. Its return value is ignored.
    """
    for step in log:
        premises, conclusions = step
        payload = {"prems": premises, "cons": conclusions}
        prover_api(payload)
# --- Robust Error Handling and Logging ---
def safe_traceback(query: Any) -> Optional[list[tuple[list[Any], list[Any]]]]:
    """Run ``recursive_traceback`` without letting exceptions escape.

    Args:
        query: The dependency to trace back from.

    Returns:
        The traceback log, or ``None`` if any ``Exception`` was raised (the
        failure is logged together with its stack trace).
    """
    try:
        trace = recursive_traceback(query)
    except Exception as e:
        logging.error(f"Traceback failed: {e}", exc_info=True)
        return None
    return trace
# --- Test Harness and Benchmarking ---
def test_traceback_module():
    """Smoke-test the analytics, streaming and plugin helpers on dummy data.

    Builds five dummy dependencies, traces them back in parallel, then feeds
    the resulting logs through the statistics, visualisation, streaming and
    plugin helpers, printing the results.
    """
    import random

    class DummyDep:
        """Minimal stand-in offering what ``recursive_traceback`` reads."""

        def __init__(self, name):
            self._name = name
            # Draw the unique suffix once: hashed() must return the same value
            # on every call, otherwise the object never matches its own hash.
            self._hashed = name + str(uuid.uuid4())

        def hashed(self):
            return self._hashed

        @property
        def rule_name(self):
            return random.choice(['', 'c0', 'collx', 'coll'])

        @property
        def why(self):
            return []

        def remove_loop(self):
            return self

    # Generate dummy log
    queries = [DummyDep(f"Q{i}") for i in range(5)]
    logs = parallel_recursive_traceback(queries)
    for h, log in logs.items():
        stats = traceback_statistics(log)
        print(f"Traceback {h}: {stats}")
        visualize_traceback_graph(log, show=False)

    # Test streaming
    streamer = TracebackStreamer()
    streamer.add_listener(lambda step: print(f"Streamed step: {step}"))
    for log in logs.values():
        streamer.stream(log)

    # Test plugin system
    class StepCountPlugin(TracebackPlugin):
        def analyze(self, log):
            return len(log)

    pm = TracebackPluginManager()
    pm.register("step_count", StepCountPlugin())
    for log in logs.values():
        print("Plugin results:", pm.run_all(log))
# NOTE(review): this guard executes at the point where it appears, i.e. before
# the definitions further down the file (for example recursive_traceback,
# which the test harness reaches through parallel_recursive_traceback) exist.
# Running the module as a script therefore fails; consider moving the guard
# to the end of the file.
if __name__ == "__main__":
    test_traceback_module()
from typing import Any
import geometry as gm
import pretty as pt
import problem
# Module-level shorthand for pt.pretty (not referenced in the visible part of
# this file).
pretty = pt.pretty
def point_levels(
    setup: list[problem.Dependency], existing_points: list[gm.Point]
) -> list[tuple[set[gm.Point], list[problem.Dependency]]]:
    """Bucket constructions and their points by point level.

    Args:
        setup: Construction dependencies; each must have at least one
            ``gm.Point`` among its ``args``.
        existing_points: Points to leave out of the per-level point sets
            (ignored when empty or ``None``).

    Returns:
        One ``(points, constructions)`` pair per level index, in increasing
        order, with empty levels dropped. A construction is filed under the
        highest ``plevel`` among its point arguments; a point is filed under
        its own ``plevel``.
    """
    buckets = []
    for dep in setup:
        point_args = [arg for arg in dep.args if isinstance(arg, gm.Point)]
        top = max([point.plevel for point in point_args])

        # Grow the bucket list until index `top` exists.
        while len(buckets) <= top:
            buckets.append((set(), []))

        for point in point_args:
            if existing_points and point in existing_points:
                continue
            buckets[point.plevel][0].add(point)

        buckets[top][1].append(dep)

    return [(points, deps) for points, deps in buckets if points or deps]
def point_log(
    setup: list[problem.Dependency],
    ref_id: dict[tuple[str, ...], int],
    existing_points: Optional[list[gm.Point]] = None,
) -> list[tuple[list[gm.Point], list[problem.Dependency]]]:
    """Reformat setup into groups of point constructions.

    Args:
        setup: Construction dependencies to group (see ``point_levels``).
        ref_id: Mapping from dependency hash to reference number. Updated in
            place: every construction not yet present is given the next
            free number (the current size of the mapping).
        existing_points: Points to leave out of the per-level point groups.
            Defaults to ``None``, meaning no point is excluded. (The original
            signature used ``=`` instead of ``:``, which made the *type*
            ``list[gm.Point]`` the default value and broke calls that
            omitted this argument.)

    Returns:
        The ``(points, constructions)`` pairs produced by ``point_levels``,
        in the same order.
    """
    log = []
    for points, cons in point_levels(setup, existing_points):
        for con in cons:
            # Keeps an existing id; otherwise assigns the next free one.
            ref_id.setdefault(con.hashed(), len(ref_id))
        log.append((points, cons))
    return log
def setup_to_levels(
    setup: list[problem.Dependency],
) -> list[list[problem.Dependency]]:
    """Group constructions by the highest point level they mention.

    Args:
        setup: Construction dependencies; each must have at least one
            ``gm.Point`` among its ``args``.

    Returns:
        Lists of dependencies ordered by increasing level, with empty levels
        removed. Within a level the input order is kept.
    """
    by_level = []
    for dep in setup:
        level = max([arg.plevel for arg in dep.args if isinstance(arg, gm.Point)])

        # Grow the list until index `level` exists.
        while len(by_level) <= level:
            by_level.append([])

        by_level[level].append(dep)

    return [group for group in by_level if group]
def separate_dependency_difference(
    query: problem.Dependency,
    log: list[tuple[list[problem.Dependency], list[problem.Dependency]]],
) -> tuple[
    list[tuple[list[problem.Dependency], list[problem.Dependency]]],
    list[problem.Dependency],
    list[problem.Dependency],
    set[gm.Point],
    set[gm.Point],
]:
    """Split a traceback log into proof steps, setup and auxiliary setup.

    Args:
        query: The conclusion being proved; its ``args`` seed the set of
            points the query depends on.
        log: Traceback steps as ``(premises, conclusions)`` pairs. The list
            itself is not modified.

    Returns:
        A 5-tuple ``(log, setup, aux_setup, points, aux_points)``:
        the remaining proof steps (with ``'ind'`` premises removed), the
        constructions that only use points the query relies on, the
        constructions that use some other point, the query's arguments plus
        everything reachable through ``rely_on``, and the other points used
        by the auxiliary constructions.
    """
    constructions = []
    proof_steps = []
    for prems, cons in log:
        # A step without premises consists purely of constructions.
        if not prems:
            constructions.extend(cons)
            continue

        derived = []
        for con in cons:
            if con.rule_name == 'c0':
                constructions.append(con)
            else:
                derived.append(con)

        if derived:
            kept_prems = [p for p in prems if p.name != 'ind']
            proof_steps.append((kept_prems, derived))

    # Breadth-first closure of the query's arguments over `rely_on`.
    points = set(query.args)
    frontier = list(query.args)
    cursor = 0
    while cursor < len(frontier):
        node = frontier[cursor]
        cursor += 1
        if isinstance(node, gm.Point):
            for parent in node.rely_on:
                if parent not in points:
                    points.add(parent)
                    frontier.append(parent)

    setup = []
    aux_setup = []
    aux_points = set()
    for con in constructions:
        if con.name == 'ind':
            continue
        foreign = [
            p for p in con.args if isinstance(p, gm.Point) and p not in points
        ]
        if foreign:
            aux_setup.append(con)
            aux_points.update(foreign)
        else:
            setup.append(con)

    return proof_steps, setup, aux_setup, points, aux_points
def recursive_traceback(
    query: problem.Dependency,
) -> list[tuple[list[problem.Dependency], list[problem.Dependency]]]:
    """Recursively traceback from the query, i.e. the conclusion.

    Args:
        query: The dependency (conclusion) to trace back from.

    Returns:
        A list of ``(premises, [conclusion])`` steps with exactly one
        conclusion each. Conclusions that share the same set of premise
        hashes are placed next to each other, at the position where that
        premise set was first recorded.
    """
    visited = set()
    log = []
    # Maps the sorted premise hashes of a log entry to its conclusion list.
    # Each premise set occurs in `log` at most once, so a dict lookup gives
    # the same entry a linear scan over `log` would find.
    groups = {}

    def read(q: problem.Dependency) -> None:
        q = q.remove_loop()
        hashed = q.hashed()
        if hashed in visited:
            return

        # These predicates are never recorded as proof steps.
        if hashed[0] in ['ncoll', 'npara', 'nperp', 'diff', 'sameside']:
            return

        prems = []

        if q.rule_name != problem.CONSTRUCTION_RULE:
            # De-duplicate the direct dependencies by hash, keeping order.
            all_deps = []
            dep_names = set()
            for d in q.why:
                if d.hashed() in dep_names:
                    continue
                dep_names.add(d.hashed())
                all_deps.append(d)

            for d in all_deps:
                h = d.hashed()
                if h not in visited:
                    read(d)
                # Only dependencies that were actually recorded count as
                # premises (skipped predicates never enter `visited`).
                if h in visited:
                    prems.append(d)

        visited.add(hashed)

        key = tuple(sorted(d.hashed() for d in prems))
        conclusions = groups.get(key)
        if conclusions is None:
            conclusions = []
            groups[key] = conclusions
            log.append((prems, conclusions))
        conclusions.append(q)

    read(query)

    # post process log: separate multi-conclusion lines
    return [(ps, [q]) for ps, qs in log for q in qs]
def collx_to_coll_setup(
    setup: list[problem.Dependency],
) -> list[problem.Dependency]:
    """Rewrite ``collx`` constructions as ``coll`` and drop duplicates.

    Dependencies named ``'collx'`` are modified in place: renamed to
    ``'coll'`` with their ``args`` reduced to distinct values. Duplicates
    (by ``hashed()``) are removed within each level of ``setup_to_levels``.

    Args:
        setup: Construction dependencies.

    Returns:
        The surviving dependencies, level by level.
    """
    converted = []
    for level in setup_to_levels(setup):
        seen = set()
        for dep in level:
            if dep.name == 'collx':
                dep.name = 'coll'
                dep.args = list(set(dep.args))

            key = dep.hashed()
            if key not in seen:
                seen.add(key)
                converted.append(dep)

    return converted
def collx_to_coll(
    setup: list[problem.Dependency],
    aux_setup: list[problem.Dependency],
    log: list[tuple[list[problem.Dependency], list[problem.Dependency]]],
) -> tuple[
    list[problem.Dependency],
    list[problem.Dependency],
    list[tuple[list[problem.Dependency], list[problem.Dependency]]],
]:
    """Rewrite ``collx`` as ``coll`` everywhere and remove duplicates.

    Args:
        setup: Construction dependencies.
        aux_setup: Auxiliary construction dependencies.
        log: Proof steps as ``(premises, conclusions)`` pairs.

    Returns:
        ``(setup, aux_setup, log)`` after conversion. Within a step, repeated
        premises are dropped; a conclusion is dropped when its hash already
        occurs in the setups or in an earlier step; steps left without
        premises or without conclusions are removed.
    """

    def _as_coll(dep):
        # Turn a collx predicate into coll (in place) and return its hash.
        if dep.name == 'collx':
            dep.name = 'coll'
            dep.args = list(set(dep.args))
        return dep.hashed()

    setup = collx_to_coll_setup(setup)
    aux_setup = collx_to_coll_setup(aux_setup)

    known_cons = {dep.hashed() for dep in setup + aux_setup}
    deduped_log = []
    for prems, cons in log:
        seen_prems = set()
        kept_prems = []
        for prem in prems:
            key = _as_coll(prem)
            if key in seen_prems:
                continue
            seen_prems.add(key)
            kept_prems.append(prem)

        kept_cons = []
        for con in cons:
            key = _as_coll(con)
            if key in known_cons:
                continue
            # Remembered even when the step itself ends up being dropped.
            known_cons.add(key)
            kept_cons.append(con)

        if kept_prems and kept_cons:
            deduped_log.append((kept_prems, kept_cons))

    return setup, aux_setup, deduped_log
def get_logs(
    query: problem.Dependency, g: Any, merge_trivials: bool = False
) -> tuple[
    list[problem.Dependency],
    list[problem.Dependency],
    list[tuple[list[problem.Dependency], list[problem.Dependency]]],
    set[gm.Point],
]:
    """Produce the setup, auxiliary setup and proof steps for a conclusion.

    Pipeline: resolve the query with ``why_me_or_cache``, trace it back,
    separate constructions from proof steps, convert ``collx`` to ``coll``,
    then shorten the proof and drop unused constructions.

    Args:
        query: The conclusion to explain.
        g: Passed to ``query.why_me_or_cache`` together with ``query.level``.
        merge_trivials: Forwarded to ``shorten_and_shave``.

    Returns:
        ``(setup, aux_setup, log, setup_points)``.
    """
    resolved = query.why_me_or_cache(g, query.level)
    raw_log = recursive_traceback(resolved)

    log, setup, aux_setup, setup_points, _unused_aux_points = (
        separate_dependency_difference(resolved, raw_log)
    )
    setup, aux_setup, log = collx_to_coll(setup, aux_setup, log)
    setup, aux_setup, log = shorten_and_shave(
        setup, aux_setup, log, merge_trivials
    )

    return setup, aux_setup, log, setup_points
def shorten_and_shave(
    setup: list[problem.Dependency],
    aux_setup: list[problem.Dependency],
    log: list[tuple[list[problem.Dependency], list[problem.Dependency]]],
    merge_trivials: bool = False,
) -> tuple[
    list[problem.Dependency],
    list[problem.Dependency],
    list[tuple[list[problem.Dependency], list[problem.Dependency]]],
]:
    """Shorten the proof, then keep only constructions it still uses.

    Args:
        setup: Construction dependencies.
        aux_setup: Auxiliary construction dependencies.
        log: Proof steps as ``(premises, conclusions)`` pairs.
        merge_trivials: Forwarded to ``shorten_proof``.

    Returns:
        ``(setup, aux_setup, log)`` where ``log`` is the shortened proof and
        both setups are filtered down to dependencies whose hash appears
        among the premises of the shortened proof.
    """
    shortened, _ = shorten_proof(log, merge_trivials=merge_trivials)

    used = {prem.hashed() for prems, _cons in shortened for prem in prems}
    kept_setup = [dep for dep in setup if dep.hashed() in used]
    kept_aux = [dep for dep in aux_setup if dep.hashed() in used]

    return kept_setup, kept_aux, shortened
def join_prems(
    con: problem.Dependency,
    con2prems: dict[tuple[str, ...], list[problem.Dependency]],
    expanded: set[tuple[str, ...]],
) -> list[problem.Dependency]:
    """Expand a dependency into the premises it is replaced by.

    Args:
        con: The dependency to expand.
        con2prems: Maps the hash of a replaceable dependency to its premises.
        expanded: Hashes that must be kept as they are.

    Returns:
        ``[con]`` when its hash is in ``expanded`` or missing from
        ``con2prems``; otherwise the concatenated expansions of its premises
        (order kept, duplicates not removed).
    """
    key = con.hashed()
    if key not in expanded and key in con2prems:
        joined = []
        for prem in con2prems[key]:
            joined.extend(join_prems(prem, con2prems, expanded))
        return joined

    return [con]
def shorten_proof(
    log: list[tuple[list[problem.Dependency], list[problem.Dependency]]],
    merge_trivials: bool = False,
) -> tuple[
    list[tuple[list[problem.Dependency], list[problem.Dependency]]],
    dict[tuple[str, ...], list[problem.Dependency]],
]:
    """Fold trivial proof steps into the steps that use them.

    A step is trivial when its conclusion's ``rule_name`` is the empty
    string. Trivial steps are skipped (the last step of the log is always
    kept) and their premises are substituted wherever their conclusion is
    used as a premise.

    Args:
        log: Proof steps; every step must have exactly one conclusion.
        merge_trivials: When false, a trivial conclusion that serves as a
            premise of a non-trivial step is kept as its own step.

    Returns:
        ``(shortened_log, con2prem)`` where ``con2prem`` maps the hash of
        each folded conclusion to its premises.
    """
    trivial = {}
    protected = set()
    for prems, cons in log:
        assert len(cons) == 1
        con = cons[0]
        if con.rule_name == '':  # pylint: disable=g-explicit-bool-comparison
            trivial[con.hashed()] = prems
        elif not merge_trivials:
            # except for the ones that are premises to non-trivial steps.
            protected.update({p.hashed() for p in prems})

    for key in protected:
        trivial.pop(key, None)

    expanded = set()
    shortened = []
    last = len(log) - 1
    for index, (prems, cons) in enumerate(log):
        con = cons[0]
        if index < last and con.hashed() in trivial:
            continue

        # Substitute folded premises, keeping the first occurrence of each.
        seen = set()
        merged = []
        for prem in prems:
            for leaf in join_prems(prem, trivial, expanded):
                key = leaf.hashed()
                if key not in seen:
                    seen.add(key)
                    merged.append(leaf)

        shortened.append((merged, [con]))
        expanded.add(con.hashed())

    return shortened, trivial