Spaces:
Sleeping
Sleeping
added a synthesizer on top
Browse files- __pycache__/abstract_syntax_tree.cpython-39.pyc +0 -0
- __pycache__/python_embedded_rasp.cpython-39.pyc +0 -0
- __pycache__/rasp_synthesizer.cpython-39.pyc +0 -0
- __pycache__/utils.cpython-39.pyc +0 -0
- abstract_syntax_tree.py +72 -0
- app.py +18 -0
- comp_flows/( tokens_int . 1 )(.1. 2.).pdf +0 -0
- outtest.txt +36 -0
- python_embedded_rasp.py +308 -0
- rasp_synthesizer.py +257 -0
- reverse-viz.ipynb +0 -0
- testouts.txt +55 -0
- tracr/__pycache__/__init__.cpython-39.pyc +0 -0
- tracr/compiler/__pycache__/__init__.cpython-39.pyc +0 -0
- tracr/compiler/__pycache__/assemble.cpython-39.pyc +0 -0
- tracr/compiler/__pycache__/basis_inference.cpython-39.pyc +0 -0
- tracr/compiler/__pycache__/compiling.cpython-39.pyc +0 -0
- tracr/compiler/__pycache__/craft_graph_to_model.cpython-39.pyc +0 -0
- tracr/compiler/__pycache__/craft_model_to_transformer.cpython-39.pyc +0 -0
- tracr/compiler/__pycache__/expr_to_craft_graph.cpython-39.pyc +0 -0
- tracr/compiler/__pycache__/nodes.cpython-39.pyc +0 -0
- tracr/compiler/__pycache__/rasp_to_graph.cpython-39.pyc +0 -0
- tracr/craft/__pycache__/__init__.cpython-39.pyc +0 -0
- tracr/craft/__pycache__/bases.cpython-39.pyc +0 -0
- tracr/craft/__pycache__/transformers.cpython-39.pyc +0 -0
- tracr/craft/__pycache__/vectorspace_fns.cpython-39.pyc +0 -0
- tracr/craft/chamber/__pycache__/__init__.cpython-39.pyc +0 -0
- tracr/craft/chamber/__pycache__/categorical_attn.cpython-39.pyc +0 -0
- tracr/craft/chamber/__pycache__/categorical_mlp.cpython-39.pyc +0 -0
- tracr/craft/chamber/__pycache__/numerical_mlp.cpython-39.pyc +0 -0
- tracr/craft/chamber/__pycache__/selector_width.cpython-39.pyc +0 -0
- tracr/rasp/__pycache__/__init__.cpython-39.pyc +0 -0
- tracr/rasp/__pycache__/rasp.cpython-39.pyc +0 -0
- tracr/transformer/__pycache__/__init__.cpython-39.pyc +0 -0
- tracr/transformer/__pycache__/attention.cpython-39.pyc +0 -0
- tracr/transformer/__pycache__/encoder.cpython-39.pyc +0 -0
- tracr/transformer/__pycache__/model.cpython-39.pyc +0 -0
- tracr/utils/__pycache__/__init__.cpython-39.pyc +0 -0
- tracr/utils/__pycache__/errors.cpython-39.pyc +0 -0
- utils.py +80 -0
__pycache__/abstract_syntax_tree.cpython-39.pyc
ADDED
|
Binary file (2.94 kB). View file
|
|
|
__pycache__/python_embedded_rasp.cpython-39.pyc
ADDED
|
Binary file (9.03 kB). View file
|
|
|
__pycache__/rasp_synthesizer.cpython-39.pyc
ADDED
|
Binary file (9.09 kB). View file
|
|
|
__pycache__/utils.cpython-39.pyc
ADDED
|
Binary file (1.69 kB). View file
|
|
|
abstract_syntax_tree.py
ADDED
|
@@ -0,0 +1,72 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
'''
|
| 2 |
+
ABSTRACT SYNTAX TREE
|
| 3 |
+
This file contains the Python class that represents programs created by our rasp synthesizer.
|
| 4 |
+
'''
|
| 5 |
+
from utils import *
|
| 6 |
+
|
| 7 |
+
class OperatorNode:
    '''
    AST node that pairs an operator with the operand nodes it is applied to.

    Args:
        operator (object): operator object (e.g., Select, Aggregate, etc.). It must
            expose `weight`, `return_type`, `n_args`, `str(*operands)` and
            `to_python(*operands)`.
        children (list): child nodes (operands). Each must expose `weight`,
            `str()` and `to_python()`.

    Attributes:
        weight: the operator's weight plus the weight of every child.
        return_type: copied from the operator.

    Example:
        select_node = OperatorNode(Select(), [Tokens(), Tokens(), Equal()])
        select_node.str() == "(select(tokens, tokens, ==))"
        select_node.evaluate("hi") == [[1, 0], [0, 1]]
        select_node.to_python() -> rasp.Select(rasp.tokens, rasp.tokens, rasp.Comparison.EQ)
    '''
    def __init__(self, operator, children):
        self.operator = operator
        self.children = children
        self.weight = operator.weight + sum(child.weight for child in children)
        self.return_type = operator.return_type

    def _check_arity(self):
        # Shared guard: the operand count must match what the operator declares.
        if len(self.children) != self.operator.n_args:
            raise ValueError("Improper number of arguments for operator.")

    def str(self):
        '''
        Render the node as a parenthesised string built from its children's strings.

        Raises:
            ValueError: if the number of children differs from `operator.n_args`.
        '''
        self._check_arity()
        operand_strings = [child.str() for child in self.children]
        return f"({self.operator.str(*operand_strings)})"

    def evaluate(self, input=None):
        '''
        Directly evaluate the python translation by calling it on `input`.
        '''
        exe = self.to_python()
        return exe(input)

        # DEPRECATED VERSION: uses the actual rasp repl
        # exe = f"({self.str()})" + f"({repr(input)});".replace("'", "\"")
        # return run_repl(exe)

    def to_python(self):
        '''
        Translate the node by translating each child and handing the results
        to the operator's own `to_python`.

        Raises:
            ValueError: if the number of children differs from `operator.n_args`.
        '''
        self._check_arity()
        operands = [child.to_python() for child in self.children]
        return self.operator.to_python(*operands)
|
| 49 |
+
|
| 50 |
+
'''
|
| 51 |
+
TESTING
|
| 52 |
+
'''
|
| 53 |
+
if __name__ == "__main__":
    from python_embedded_rasp import *
    from tracr.rasp import rasp

    # Smoke test: build select(tokens, tokens, ==) and check each view of it.
    # (open question from the author: should children be operators or operator nodes? maybe either?)
    node = OperatorNode(Select(), [Tokens(), Tokens(), Equal()])

    # One operator plus three weight-1 leaves.
    assert node.weight == 4

    # String rendering.
    expected_str = "(select(tokens, tokens, ==))"
    assert node.str() == expected_str

    # Direct evaluation on a two-token input.
    expected_result = [[1, 0], [0, 1]]
    assert node.evaluate("hi") == expected_result

    # Python translation has the same type as a hand-built rasp.Select.
    reference = rasp.Select(rasp.tokens, rasp.tokens, rasp.Comparison.EQ)
    assert type(node.to_python()) is type(reference)

    print("all tests passed hooray!")
|
app.py
ADDED
|
@@ -0,0 +1,18 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
'''
|
| 2 |
+
For future reference with downloading model files:
|
| 3 |
+
|
| 4 |
+
import streamlit as st
|
| 5 |
+
import pickle
|
| 6 |
+
import base64
|
| 7 |
+
|
| 8 |
+
x = {"my": "data"}
|
| 9 |
+
|
| 10 |
+
def download_model(model):
|
| 11 |
+
output_model = pickle.dumps(model)
|
| 12 |
+
b64 = base64.b64encode(output_model).decode()
|
| 13 |
+
href = f'<a href="data:file/output_model;base64,{b64}" download="myfile.pkl">Download Trained Model .pkl File</a>'
|
| 14 |
+
st.markdown(href, unsafe_allow_html=True)
|
| 15 |
+
|
| 16 |
+
|
| 17 |
+
download_model(x)
|
| 18 |
+
'''
|
comp_flows/( tokens_int . 1 )(.1. 2.).pdf
ADDED
|
Binary file (19.2 kB). View file
|
|
|
outtest.txt
ADDED
|
@@ -0,0 +1,36 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
Received the following input and output examples:
|
| 2 |
+
[(['h', 'e', 'l', 'l', 'o'], [1, 1, 2, 2, 1])]
|
| 3 |
+
Running synthesizer with
|
| 4 |
+
Vocab: {'o', 'e', 'h', 'l'}
|
| 5 |
+
Max sequence length: 5
|
| 6 |
+
Max weight: 25
|
| 7 |
+
(indices - indices)
|
| 8 |
+
[[0, 0, 0, 0, 0]]
|
| 9 |
+
(indices - 0)
|
| 10 |
+
[[0, 1, 2, 3, 4]]
|
| 11 |
+
(indices - 1)
|
| 12 |
+
[[-1, 0, 1, 2, 3]]
|
| 13 |
+
(0 - indices)
|
| 14 |
+
[[0, -1, -2, -3, -4]]
|
| 15 |
+
(1 - indices)
|
| 16 |
+
[[1, 0, -1, -2, -3]]
|
| 17 |
+
(select(tokens, tokens, ==))
|
| 18 |
+
[[[True, False, False, False, False], [False, True, False, False, False], [False, False, True, True, False], [False, False, True, True, False], [False, False, False, False, True]]]
|
| 19 |
+
(select(tokens, tokens, true))
|
| 20 |
+
[[[True, True, True, True, True], [True, True, True, True, True], [True, True, True, True, True], [True, True, True, True, True], [True, True, True, True, True]]]
|
| 21 |
+
(select(tokens, indices, ==))
|
| 22 |
+
[[[False, False, False, False, False], [False, False, False, False, False], [False, False, False, False, False], [False, False, False, False, False], [False, False, False, False, False]]]
|
| 23 |
+
(select(tokens, indices, true))
|
| 24 |
+
[[[True, True, True, True, True], [True, True, True, True, True], [True, True, True, True, True], [True, True, True, True, True], [True, True, True, True, True]]]
|
| 25 |
+
(select(indices, tokens, ==))
|
| 26 |
+
[[[False, False, False, False, False], [False, False, False, False, False], [False, False, False, False, False], [False, False, False, False, False], [False, False, False, False, False]]]
|
| 27 |
+
(select(indices, tokens, true))
|
| 28 |
+
[[[True, True, True, True, True], [True, True, True, True, True], [True, True, True, True, True], [True, True, True, True, True], [True, True, True, True, True]]]
|
| 29 |
+
(select(indices, indices, ==))
|
| 30 |
+
[[[True, False, False, False, False], [False, True, False, False, False], [False, False, True, False, False], [False, False, False, True, False], [False, False, False, False, True]]]
|
| 31 |
+
(select(indices, indices, true))
|
| 32 |
+
[[[True, True, True, True, True], [True, True, True, True, True], [True, True, True, True, True], [True, True, True, True, True], [True, True, True, True, True]]]
|
| 33 |
+
(select_width((select(tokens, tokens, ==))))
|
| 34 |
+
[[1, 1, 2, 2, 1]]
|
| 35 |
+
The following program has been compiled to a transformer with 1 layer(s):
|
| 36 |
+
(select_width((select(tokens, tokens, ==))))
|
python_embedded_rasp.py
ADDED
|
@@ -0,0 +1,308 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
'''
|
| 2 |
+
RASP OPERATORS THAT ARE SUPPORTED BY TRACR'S PYTHON EMBEDDING
|
| 3 |
+
This file contains Python classes that define the rasp operators supported by TRACR's python embedding of the language.
|
| 4 |
+
This is a subset of everything that TRACR supports in python, due to project time constraints.
|
| 5 |
+
'''
|
| 6 |
+
import random
|
| 7 |
+
from typing import (Any, Callable, Dict, Generic, List, Mapping, Optional,
|
| 8 |
+
Sequence, TypeVar, Union)
|
| 9 |
+
from tracr.rasp import rasp
|
| 10 |
+
import subprocess
|
| 11 |
+
import time
|
| 12 |
+
|
| 13 |
+
'''
|
| 14 |
+
CLASS DEFINITIONS
|
| 15 |
+
'''
|
| 16 |
+
class Tokens:
    '''
    Leaf primitive for the rasp `tokens` s-op.
    '''
    def __init__(self):
        # Takes no operands and yields an s-op.
        self.arg_types = []
        self.n_args = 0
        self.weight = 1
        self.return_type = rasp.SOp

    def str(self):
        # String form used when a program is printed.
        return "tokens"

    def to_python(self):
        # The corresponding object from tracr's rasp module.
        return rasp.tokens
|
| 35 |
+
|
| 36 |
+
class Indices:
    '''
    Leaf primitive for the rasp `indices` s-op.
    '''
    def __init__(self):
        # Takes no operands and yields an s-op.
        self.arg_types = []
        self.n_args = 0
        self.weight = 1
        self.return_type = rasp.SOp

    def str(self):
        # String form used when a program is printed.
        return "indices"

    def to_python(self):
        # The corresponding object from tracr's rasp module.
        return rasp.indices
|
| 52 |
+
|
| 53 |
+
class Zero:
    '''
    Leaf primitive for the integer constant 0.
    '''
    def __init__(self):
        # Takes no operands and yields a plain int.
        self.arg_types = []
        self.n_args = 0
        self.weight = 1
        self.return_type = int

    def str(self):
        # String form used when a program is printed.
        return "0"

    def to_python(self):
        # The literal value itself.
        return 0
|
| 69 |
+
|
| 70 |
+
class One:
    '''
    Leaf primitive for the integer constant 1.
    '''
    def __init__(self):
        # Takes no operands and yields a plain int.
        self.arg_types = []
        self.n_args = 0
        self.weight = 1
        self.return_type = int

    def str(self):
        # String form used when a program is printed.
        return "1"

    def to_python(self):
        # The literal value itself.
        return 1
|
| 86 |
+
|
| 87 |
+
class Equal:
    '''
    Leaf primitive for the equality comparison predicate.
    '''
    def __init__(self):
        # Takes no operands and yields a predicate.
        self.arg_types = []
        self.n_args = 0
        self.weight = 1
        self.return_type = rasp.Predicate

    def str(self):
        # String form used when a program is printed.
        return "=="

    def to_python(self):
        # The corresponding comparison from tracr's rasp module.
        return rasp.Comparison.EQ
|
| 106 |
+
|
| 107 |
+
class GT:
    '''
    Greater Than comparison operator.

    Placeholder: no attributes or methods are defined yet.
    '''
|
| 112 |
+
|
| 113 |
+
class LT:
    '''
    Less Than comparison operator.

    Placeholder: no attributes or methods are defined yet.
    '''
|
| 118 |
+
|
| 119 |
+
class LEQ:
    '''
    Placeholder class: no attributes or methods are defined yet.
    '''
|
| 121 |
+
|
| 122 |
+
class GEQ:
    '''
    Placeholder class: no attributes or methods are defined yet.
    '''
|
| 124 |
+
|
| 125 |
+
class TRUE:
    '''
    Leaf primitive for the always-true comparison predicate.
    '''
    def __init__(self):
        # Takes no operands and yields a predicate.
        self.arg_types = []
        self.n_args = 0
        self.weight = 1
        self.return_type = rasp.Predicate

    def str(self):
        # String form used when a program is printed.
        return "true"

    def to_python(self):
        # The corresponding comparison from tracr's rasp module.
        return rasp.Comparison.TRUE
|
| 144 |
+
|
| 145 |
+
class FALSE:
    '''
    Placeholder class: no attributes or methods are defined yet.
    '''
|
| 147 |
+
|
| 148 |
+
class Add:
    '''
    Element-wise.
    Input can be either int, float or s-op.

    Placeholder: no attributes or methods are defined yet.
    '''
|
| 154 |
+
|
| 155 |
+
class Subtract:
    '''
    Element-wise subtraction.
    Input can be either int, float or s-op.
    '''
    def __init__(self):
        self.n_args = 2
        # Both operands accept an s-op or a plain number.
        self.arg_types = [
            Union[rasp.SOp, float, int],
            Union[rasp.SOp, float, int],
        ]
        self.return_type = Union[rasp.SOp, int, float]
        self.weight = 1

    def str(self, x, y):
        # Operands arrive already rendered as strings.
        return f"{x} - {y}"

    def to_python(self, x, y):
        # Operands arrive as python objects. If either one has the same type as
        # rasp.tokens, no expression is built and None is returned.
        tokens_type = type(rasp.tokens)
        if tokens_type in (type(x), type(y)):
            return None
        return x - y
|
| 179 |
+
|
| 180 |
+
class Mult:
    '''
    Element-wise.
    Input can be either int, float or s-op.

    Placeholder: no attributes or methods are defined yet.
    '''
|
| 186 |
+
|
| 187 |
+
class Divide:
    '''
    Element-wise.
    Input can be either int, float or s-op.

    Placeholder: no attributes or methods are defined yet.
    '''
|
| 193 |
+
|
| 194 |
+
class Fill:
    '''
    Given fill value and length, returns Sop of that length with that fill value.
    Fill value can be int, float, or char.
    Length must be a positive integer.

    Placeholder: no attributes or methods are defined yet.
    '''
|
| 201 |
+
|
| 202 |
+
class SelectorAnd:
    '''
    Input can be bool or s-op.

    Placeholder: no attributes or methods are defined yet.
    '''
|
| 207 |
+
|
| 208 |
+
class SelectorOr:
    '''
    Input can be bool or s-op.

    Placeholder: no attributes or methods are defined yet.
    '''
|
| 213 |
+
|
| 214 |
+
class SelectorNot:
    '''
    Input is an s-op of bools. (Or bool-convertible values.)

    Placeholder: no attributes or methods are defined yet.
    '''
|
| 219 |
+
|
| 220 |
+
class Select:
    '''
    Select operator: two s-ops and a predicate in, a selector out.
    '''
    def __init__(self):
        self.arg_types = [rasp.SOp, rasp.SOp, rasp.Predicate]
        self.n_args = 3
        self.weight = 1
        self.return_type = rasp.Selector

    def str(self, sop1, sop2, comp):
        # Operands arrive already rendered as strings.
        return f"select({sop1}, {sop2}, {comp})"

    def to_python(self, sop1, sop2, comp):
        # Operands arrive as python objects; wrap them in a rasp.Select.
        return rasp.Select(sop1, sop2, comp)
|
| 239 |
+
|
| 240 |
+
class Aggregate:
    '''
    The Aggregate operator: a selector and an s-op in, an s-op out.
    '''
    def __init__(self):
        self.arg_types = [rasp.Selector, rasp.SOp]
        self.n_args = 2
        self.weight = 1
        self.return_type = rasp.SOp

    def str(self, sel, sop):
        # Operands arrive already rendered as strings.
        return f"aggregate({sel}, {sop})"

    def to_python(self, sel, sop):
        # Operands arrive as python objects; wrap them in a rasp.Aggregate.
        return rasp.Aggregate(sel, sop)
|
| 259 |
+
|
| 260 |
+
class SelectorWidth:
    '''
    The selector_width operator: a selector in, an s-op out.
    '''
    def __init__(self):
        self.arg_types = [rasp.Selector]
        self.n_args = 1
        self.weight = 1
        self.return_type = rasp.SOp

    def str(self, sel):
        # The operand arrives already rendered as a string.
        return f"select_width({sel})"

    def to_python(self, sel):
        # The operand arrives as a python object; wrap it in a rasp.SelectorWidth.
        return rasp.SelectorWidth(sel)
|
| 279 |
+
|
| 280 |
+
'''
|
| 281 |
+
GLOBAL CONSTANTS
|
| 282 |
+
'''
|
| 283 |
+
|
| 284 |
+
# Operators that take operands, followed by the argument-free leaf primitives.
# NOTE(review): Tokens() and Indices() each appear twice in rasp_consts. Presumably
# this lets one primitive fill two operand slots of the same operator. Confirm
# against the synthesizer before de-duplicating.
rasp_operators = [
    Select(),
    SelectorWidth(),
    Aggregate(),
    Subtract(),
]
rasp_consts = [
    Tokens(),
    Tokens(),
    Equal(),
    TRUE(),
    Indices(),
    Indices(),
    Zero(),
    One(),
]
|
| 287 |
+
'''
|
| 288 |
+
TESTING
|
| 289 |
+
'''
|
| 290 |
+
if __name__ == "__main__":
    # Smoke tests for the operator wrappers defined in this file.
    test_select = Select()

    # to_python(): the wrapper builds the same kind of object as calling rasp directly.
    test_select_python = test_select.to_python(Tokens().to_python(), Tokens().to_python(), Equal().to_python())
    actual_ts_python = rasp.Select(rasp.tokens, rasp.tokens, rasp.Comparison.EQ)
    assert type(Tokens().to_python()) == type(rasp.tokens)
    # Fixed: the closing parenthesis used to wrap the whole comparison, so the
    # statement asserted a type object (always truthy) and could never fail.
    assert type(Equal().to_python()) == type(rasp.Comparison.EQ)
    assert type(test_select_python) == type(actual_ts_python)

    # str(): string rendering from already-rendered operand strings.
    test_select_string = test_select.str(Tokens().str(), Tokens().str(), Equal().str())
    actual_ts_string = "select(tokens, tokens, ==)"
    assert test_select_string == actual_ts_string

    # Aggregate: only construction is exercised; the print shows a direct rasp evaluation.
    test_aggregate = Aggregate()
    print(rasp.Aggregate(rasp.Select(rasp.tokens, rasp.tokens, rasp.Comparison.EQ), rasp.tokens)("hi"))

    print("all tests passed hooray!")
|
| 308 |
+
|
rasp_synthesizer.py
ADDED
|
@@ -0,0 +1,257 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
'''
|
| 2 |
+
BOTTOM-UP ENUMERATIVE SYNTHESIS FOR RASP
|
| 3 |
+
|
| 4 |
+
Usage:
|
| 5 |
+
python rasp_synthesizer.py --examples "<list of [input, output] pairs>" [--max_weight N]
|
| 6 |
+
'''
|
| 7 |
+
import numpy as np
|
| 8 |
+
import argparse
|
| 9 |
+
import itertools
|
| 10 |
+
import time
|
| 11 |
+
import ast
|
| 12 |
+
import re
|
| 13 |
+
from tracr.compiler import compiling
|
| 14 |
+
from typing import get_args
|
| 15 |
+
import inspect
|
| 16 |
+
|
| 17 |
+
from abstract_syntax_tree import *
|
| 18 |
+
from python_embedded_rasp import *
|
| 19 |
+
|
| 20 |
+
# PARSE ARGUMENTS
|
| 21 |
+
def parse_args(argv=None):
    '''
    Parse command line arguments.

    Args:
        argv (list[str] | None): argument strings to parse. None (the default)
            makes argparse read the process command line, as before.

    Returns:
        argparse.Namespace with `examples` (str, required) and `max_weight`
        (int, default 10).
    '''
    parser = argparse.ArgumentParser(description="Bottom-up enumerative synthesis for RASP.")
    parser.add_argument('--examples', required=True, help="input/output sequence examples for synthesis")
    parser.add_argument('--max_weight', type=int, required=False, default=10, help="Maximum weight of programs to consider before terminating search.")
    return parser.parse_args(argv)
|
| 30 |
+
|
| 31 |
+
# ANALYZE EXAMPLES
|
| 32 |
+
def analyze_examples(inputs):
    '''
    Parse the examples string into parallel lists of input and output sequences.

    Args:
        inputs (str): Python literal for a list of (input_sequence, output_sequence)
            pairs, e.g. "[[['h', 'i'], [1, 1]]]".

    Returns:
        (example_ins, example_outs): two lists, in the order the examples were given.
        Every sequence holds only ints, only floats, only bools or only strs.

    Raises:
        argparse.ArgumentTypeError: if the literal cannot be parsed, an example is
            not an (input, output) pair, or a sequence mixes element types.
        ValueError: if the parsed literal is not a list.
    '''
    try:
        # Safely evaluate the string to a Python object
        examples_lst = ast.literal_eval(inputs)
    except (SyntaxError, ValueError) as e:
        raise argparse.ArgumentTypeError(f"Invalid examples format: {e}") from e

    if not isinstance(examples_lst, list):
        raise ValueError("Input should be a list.")

    example_ins = []
    example_outs = []
    for ex in examples_lst:
        try:
            ins, outs = ex[0], ex[1]
            consistent = _same_legal_type(ins) and _same_legal_type(outs)
        except (TypeError, IndexError, KeyError) as e:
            # Not indexable as a pair, or a "sequence" that cannot be iterated.
            raise argparse.ArgumentTypeError(f"Invalid examples format.") from e

        if not consistent:
            in_types = [type(x) for x in ins]
            out_types = [type(x) for x in outs]
            # The expected type is whatever the first element of each sequence has.
            first_in_type = in_types[0] if in_types else None
            first_out_type = out_types[0] if out_types else None
            raise argparse.ArgumentTypeError(f"Each example must have consistent types. Expected inputs to have type {first_in_type} and outputs to have {first_out_type} but instead inputs have types {in_types} and outputs have types {out_types}")

        example_ins.append(ins)
        example_outs.append(outs)

    return example_ins, example_outs


def _same_legal_type(lst):
    # True when every element is an instance of one of the four accepted types.
    return any(all(isinstance(x, kind) for x in lst) for kind in (int, float, bool, str))
|
| 67 |
+
|
| 68 |
+
# GET VOCABULARY
|
| 69 |
+
def get_vocabulary(examples):
    '''
    Returns vocabulary for later compiling the RASP model.

    Args:
        examples: iterable of examples whose first element is the input sequence.

    Returns:
        set of every distinct element found in the input sequences. Output
        sequences are not consulted.
    '''
    return {token for ex in examples for token in ex[0]}
|
| 78 |
+
|
| 79 |
+
# CHECK OBSERVATIONAL EQUIVALENCE
|
| 80 |
+
def check_obs_equivalence(examples, program_a, program_b):
    '''
    Report whether two programs produce the same outputs on every example input.

    A program without an `evaluate` method (the leaf primitives) is never run;
    its output is taken to be None. This check used to be membership in
    `rasp_consts`, which broke once synthesized programs were appended to that
    same list: banked programs were then treated as constants and never evaluated.

    Args:
        examples: sequence of examples whose first element is the input.
        program_a, program_b: the two programs to compare.

    Returns:
        True if the outputs are equal, or if evaluation raised (so the caller
        discards the candidate); False otherwise.
    '''
    try:
        inputs = [example[0] for example in examples]
        a_output = _evaluate_all(program_a, inputs)
        b_output = _evaluate_all(program_b, inputs)
    except Exception:
        return True  # force the synthesizer to not consider this program

    return a_output == b_output


def _evaluate_all(program, inputs):
    # Run the program on each input, or return None for programs that cannot be run.
    if not hasattr(program, "evaluate"):
        return None
    return [program.evaluate(value) for value in inputs]
|
| 93 |
+
|
| 94 |
+
# CHECK CORRECTNESS
|
| 95 |
+
def check_correctness(examples, program):
    '''
    Checks if the program's output matches the expected output on all examples.

    Prints the program's string form and its outputs whenever evaluation succeeds.

    Args:
        examples: sequence of (input, expected_output) pairs.
        program: object exposing `evaluate(input)` and `str()`.

    Returns:
        True if every output equals the expected one; False if any differs or
        if evaluation raised.
    '''
    try:
        inputs = [example[0] for example in examples]
        outputs = [example[1] for example in examples]
        program_output = [program.evaluate(value) for value in inputs]
    except Exception:
        # Was a bare `except:`, which also swallowed KeyboardInterrupt/SystemExit.
        return False

    print(program.str())
    print(program_output)

    # TODO return number that match and return this

    return program_output == outputs
|
| 112 |
+
|
| 113 |
+
# COMPARE TYPE SIGNATURES
|
| 114 |
+
def compare_types(list1, list2):
    '''
    Check whether a list of actual types satisfies a list of expected types.

    Position by position, an actual type is accepted when:
      - the expected type is a Union and the actual type is one of its members, or
        is itself a Union whose members are all members of the expected Union;
      - otherwise, the two types are equal.

    Args:
        list1: actual types (e.g. the return types of candidate operands).
        list2: expected types (e.g. an operator's `arg_types`).

    Returns:
        True only if both lists have the same length and every position is
        accepted. Previously a shorter (or empty) list1 was accepted.
    '''
    if len(list1) != len(list2):
        return False
    return all(_type_accepts(expected, actual) for actual, expected in zip(list1, list2))


def _union_members(tp):
    # Member types of a typing.Union, or None when tp is not a Union.
    if getattr(tp, '__origin__', None) is Union:
        return get_args(tp)
    return None


def _type_accepts(expected, actual):
    # Decide whether a single actual type is acceptable where `expected` is required.
    expected_members = _union_members(expected)
    if expected_members is None:
        # Direct type comparison
        return actual == expected

    actual_members = _union_members(actual)
    if actual_members is None:
        return any(actual == member for member in expected_members)
    return all(any(a == member for member in expected_members) for a in actual_members)
|
| 141 |
+
|
| 142 |
+
# RUN SYNTHESIZER
|
| 143 |
+
def run_synthesizer(examples, max_weight):
    '''
    Run bottom-up enumerative synthesis.

    Starting from the leaf primitives in `rasp_consts`, repeatedly applies each
    operator in `rasp_operators` to type-compatible arrangements of banked
    programs, keeps candidates that are new and not observationally equivalent
    to a banked program, and stops at the first candidate that reproduces every
    example.

    Args:
        examples: sequence of (input, expected_output) pairs.
        max_weight (int): upper bound (exclusive) of the weight levels searched.

    Returns:
        The first correct OperatorNode found, or None if the search is exhausted.
    '''
    # Work on a copy: the bank grows during the search, and appending to the
    # module-level rasp_consts list would change what counts as a constant elsewhere.
    program_bank = list(rasp_consts)
    seen_programs = {p.str() for p in program_bank}

    # TODO: store approximate programs, measured by number of output examples that match

    # iterate over each level
    # NOTE(review): range() stops before max_weight, and the per-level check below
    # counts operand weights only (not the operator's). Confirm both are intended.
    for weight in range(2, max_weight):

        for op in rasp_operators:
            # permutations() snapshots the bank, so appending below is safe.
            for combination in itertools.permutations(program_bank, op.n_args):

                type_signature = [p.return_type for p in combination]
                if not compare_types(type_signature, op.arg_types):
                    continue

                if sum(p.weight for p in combination) > weight:
                    continue

                program = OperatorNode(op, combination)
                program_str = program.str()

                if program_str in seen_programs:
                    continue

                # Generator form stops at the first equivalent banked program.
                if any(check_obs_equivalence(examples, program, p) for p in program_bank):
                    continue

                program_bank.append(program)
                seen_programs.add(program_str)

                if check_correctness(examples, program):
                    return program

    return None
|
| 183 |
+
|
| 184 |
+
# COMPILE RASP MODEL
|
| 185 |
+
if __name__ == "__main__":

    '''
    Some examples:
    Identify anagrams:
    [[['V','I','W',',','W','I','V'], [True, True, True, True, True, True, True]],[['a','b',',','b','a'], [True, True, True, True, True]],[['e','l',',','s','t'], [False, False, False, False, False]]]
    Output: times out
    Calculate the median of a list of numbers:
    [[[1,2,3,4,5], [3,3,3,3,3]], [[2,8,10,11], [9,9,9,9]], [[1,2,3],[2,2,2]]]
    Output: times out
    Identity function:
    [[['h','i'], ['h','i']]]
    Output: (aggregate((select(tokens, tokens, ==)), tokens))
    Histogram:
    [[['h', 'e', 'l', 'l', 'o'], [1,1,2,2,1]]]
    Output: (select_width((select(tokens, tokens, ==))))
    Length:
    [[[7,2,5],[3,3,3]],[[1],[1]],[[2,0,1,7,3,6,8,20],[8,8,8,8,8,8,8,8]]]
    Output: (select_width((select(tokens, tokens, true))))
    Calculate mean of list of numbers:
    [[[5,10,3,2,43], [12.6, 12.6, 12.6, 12.6, 12.6]],[[1,2], [1.5, 1.5]],[[3,3,3],[3,3,3]]]
    Output: (aggregate((select(tokens, tokens, true)), tokens))
    Reverse a string:
    [[['h', 'i'], ['i', 'h']]]
    Output: times out
    Expected: aggregate(select(indices, (select_width((select(tokens, tokens, true)))) - indices - 1, ==), tokens);
    PERSONAL TODOS:
    - output several similar programs
    -

    '''

    # Read the command line and turn the examples string into (input, output) pairs.
    args = parse_args()
    inputs, outs = analyze_examples(args.examples)
    examples = list(zip(inputs, outs))
    print("Received the following input and output examples:")
    print(examples)

    # Compiler settings derived from the examples: longest input and input vocabulary.
    max_seq_len = max((len(seq) for seq in inputs), default=0)
    vocab = get_vocabulary(examples)

    print("Running synthesizer with")
    print("Vocab: {}".format(vocab))
    print("Max sequence length: {}".format(max_seq_len))
    print("Max weight: {}".format(args.max_weight))

    program = run_synthesizer(examples, args.max_weight)

    if not program:
        print("No program found.")
    else:
        # Compile the synthesized program into a transformer.
        algorithm = program.to_python()

        bos = "BOS"
        model = compiling.compile_rasp_to_model(
            algorithm,
            vocab=vocab,
            max_seq_len=max_seq_len,
            compiler_bos=bos,
        )

        def extract_layer_number(s):
            # Pull N out of a "layer_N" parameter name and report N + 1.
            found = re.search(r'layer_(\d+)', s)
            if not found:
                return None
            return int(found.group(1)) + 1

        # The layer count is read from the last parameter key of the compiled model.
        last_param_name = list(model.params.keys())[-1]
        layer_num = extract_layer_number(last_param_name)
        print(f"The following program has been compiled to a transformer with {layer_num} layer(s):")
        print(program.str())
|
reverse-viz.ipynb
ADDED
|
The diff for this file is too large to render.
See raw diff
|
|
|
testouts.txt
ADDED
|
@@ -0,0 +1,55 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
Received the following input and output examples:
|
| 2 |
+
[(['h', 'i'], ['h', 'h'])]
|
| 3 |
+
Running synthesizer with
|
| 4 |
+
Vocab: {'h', 'i'}
|
| 5 |
+
Max sequence length: 2
|
| 6 |
+
Max weight: 15
|
| 7 |
+
- Searching level 2 with 4 primitives.
|
| 8 |
+
- Searching level 3 with 4 primitives.
|
| 9 |
+
(select(tokens, tokens, ==))
|
| 10 |
+
[[[True, False], [False, True]]]
|
| 11 |
+
(select(tokens, tokens, true))
|
| 12 |
+
[[[True, True], [True, True]]]
|
| 13 |
+
- Searching level 4 with 6 primitives.
|
| 14 |
+
(select_width((select(tokens, tokens, ==))))
|
| 15 |
+
[[1, 1]]
|
| 16 |
+
(select_width((select(tokens, tokens, true))))
|
| 17 |
+
[[2, 2]]
|
| 18 |
+
- Searching level 5 with 8 primitives.
|
| 19 |
+
- Searching level 6 with 8 primitives.
|
| 20 |
+
- Searching level 7 with 8 primitives.
|
| 21 |
+
- Searching level 8 with 8 primitives.
|
| 22 |
+
- Searching level 9 with 8 primitives.
|
| 23 |
+
(aggregate((select(tokens, tokens, ==)), (select_width((select(tokens, tokens, ==))))))
|
| 24 |
+
[[1.0, 1.0]]
|
| 25 |
+
(aggregate((select(tokens, tokens, ==)), (select_width((select(tokens, tokens, true))))))
|
| 26 |
+
[[2.0, 2.0]]
|
| 27 |
+
(aggregate((select(tokens, tokens, true)), (select_width((select(tokens, tokens, ==))))))
|
| 28 |
+
[[1.0, 1.0]]
|
| 29 |
+
(aggregate((select(tokens, tokens, true)), (select_width((select(tokens, tokens, true))))))
|
| 30 |
+
[[2.0, 2.0]]
|
| 31 |
+
- Searching level 10 with 12 primitives.
|
| 32 |
+
- Searching level 11 with 12 primitives.
|
| 33 |
+
- Searching level 12 with 12 primitives.
|
| 34 |
+
- Searching level 13 with 12 primitives.
|
| 35 |
+
- Searching level 14 with 12 primitives.
|
| 36 |
+
(aggregate((select(tokens, tokens, ==)), (aggregate((select(tokens, tokens, ==)), (select_width((select(tokens, tokens, ==))))))))
|
| 37 |
+
[[1.0, 1.0]]
|
| 38 |
+
(aggregate((select(tokens, tokens, ==)), (aggregate((select(tokens, tokens, ==)), (select_width((select(tokens, tokens, true))))))))
|
| 39 |
+
[[2.0, 2.0]]
|
| 40 |
+
(aggregate((select(tokens, tokens, ==)), (aggregate((select(tokens, tokens, true)), (select_width((select(tokens, tokens, ==))))))))
|
| 41 |
+
[[1.0, 1.0]]
|
| 42 |
+
(aggregate((select(tokens, tokens, ==)), (aggregate((select(tokens, tokens, true)), (select_width((select(tokens, tokens, true))))))))
|
| 43 |
+
[[2.0, 2.0]]
|
| 44 |
+
> c:\users\18084\desktop\cs252r\final_project\tracr-synthesis\rasp_synthesizer.py(94)check_obs_equivalence()
|
| 45 |
+
-> return a_output == b_output
|
| 46 |
+
(Pdb) --KeyboardInterrupt--
|
| 47 |
+
(Pdb) --KeyboardInterrupt--
|
| 48 |
+
(Pdb) --KeyboardInterrupt--
|
| 49 |
+
(Pdb) *** SyntaxError: invalid syntax
|
| 50 |
+
(Pdb) --KeyboardInterrupt--
|
| 51 |
+
(Pdb) *** SyntaxError: invalid syntax
|
| 52 |
+
(Pdb) --KeyboardInterrupt--
|
| 53 |
+
(Pdb) --KeyboardInterrupt--
|
| 54 |
+
(Pdb) *** SyntaxError: invalid syntax
|
| 55 |
+
(Pdb)
|
tracr/__pycache__/__init__.cpython-39.pyc
ADDED
|
Binary file (168 Bytes). View file
|
|
|
tracr/compiler/__pycache__/__init__.cpython-39.pyc
ADDED
|
Binary file (341 Bytes). View file
|
|
|
tracr/compiler/__pycache__/assemble.cpython-39.pyc
ADDED
|
Binary file (9.98 kB). View file
|
|
|
tracr/compiler/__pycache__/basis_inference.cpython-39.pyc
ADDED
|
Binary file (2.97 kB). View file
|
|
|
tracr/compiler/__pycache__/compiling.cpython-39.pyc
ADDED
|
Binary file (2.48 kB). View file
|
|
|
tracr/compiler/__pycache__/craft_graph_to_model.cpython-39.pyc
ADDED
|
Binary file (6.71 kB). View file
|
|
|
tracr/compiler/__pycache__/craft_model_to_transformer.cpython-39.pyc
ADDED
|
Binary file (1.69 kB). View file
|
|
|
tracr/compiler/__pycache__/expr_to_craft_graph.cpython-39.pyc
ADDED
|
Binary file (7.4 kB). View file
|
|
|
tracr/compiler/__pycache__/nodes.cpython-39.pyc
ADDED
|
Binary file (442 Bytes). View file
|
|
|
tracr/compiler/__pycache__/rasp_to_graph.cpython-39.pyc
ADDED
|
Binary file (1.78 kB). View file
|
|
|
tracr/craft/__pycache__/__init__.cpython-39.pyc
ADDED
|
Binary file (174 Bytes). View file
|
|
|
tracr/craft/__pycache__/bases.cpython-39.pyc
ADDED
|
Binary file (10 kB). View file
|
|
|
tracr/craft/__pycache__/transformers.cpython-39.pyc
ADDED
|
Binary file (7.64 kB). View file
|
|
|
tracr/craft/__pycache__/vectorspace_fns.cpython-39.pyc
ADDED
|
Binary file (5.32 kB). View file
|
|
|
tracr/craft/chamber/__pycache__/__init__.cpython-39.pyc
ADDED
|
Binary file (182 Bytes). View file
|
|
|
tracr/craft/chamber/__pycache__/categorical_attn.cpython-39.pyc
ADDED
|
Binary file (4.25 kB). View file
|
|
|
tracr/craft/chamber/__pycache__/categorical_mlp.cpython-39.pyc
ADDED
|
Binary file (5.04 kB). View file
|
|
|
tracr/craft/chamber/__pycache__/numerical_mlp.cpython-39.pyc
ADDED
|
Binary file (10.6 kB). View file
|
|
|
tracr/craft/chamber/__pycache__/selector_width.cpython-39.pyc
ADDED
|
Binary file (4.54 kB). View file
|
|
|
tracr/rasp/__pycache__/__init__.cpython-39.pyc
ADDED
|
Binary file (173 Bytes). View file
|
|
|
tracr/rasp/__pycache__/rasp.cpython-39.pyc
ADDED
|
Binary file (36.6 kB). View file
|
|
|
tracr/transformer/__pycache__/__init__.cpython-39.pyc
ADDED
|
Binary file (180 Bytes). View file
|
|
|
tracr/transformer/__pycache__/attention.cpython-39.pyc
ADDED
|
Binary file (4.83 kB). View file
|
|
|
tracr/transformer/__pycache__/encoder.cpython-39.pyc
ADDED
|
Binary file (5.39 kB). View file
|
|
|
tracr/transformer/__pycache__/model.cpython-39.pyc
ADDED
|
Binary file (5.25 kB). View file
|
|
|
tracr/utils/__pycache__/__init__.cpython-39.pyc
ADDED
|
Binary file (174 Bytes). View file
|
|
|
tracr/utils/__pycache__/errors.cpython-39.pyc
ADDED
|
Binary file (928 Bytes). View file
|
|
|
utils.py
ADDED
|
@@ -0,0 +1,80 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import subprocess
|
| 2 |
+
import time
|
| 3 |
+
import re
|
| 4 |
+
import ast
|
| 5 |
+
|
| 6 |
+
# Path to the Python interpreter used to launch the RASP REPL subprocess
|
| 7 |
+
python_exe = '/Users/18084/Desktop/CS252R/final_project/rasp-env-py3.9/Scripts/python.exe' #SETUP THING: replace with path to your python environment
|
| 8 |
+
|
| 9 |
+
'''
|
| 10 |
+
THE FOLLOWING FUNCTIONS ARE DEPRECATED
|
| 11 |
+
'''
|
| 12 |
+
def clean_carrots(text):
    '''
    Returns the text found between the first pair of ">>" markers in text,
    with surrounding whitespace removed.

    Returns None when no such pair exists. The two markers must not be
    separated by a newline, because "." does not match newlines here.
    '''
    found = re.search(r">>(.*?)>>", text)
    if found is None:
        return None
    return found.group(1).strip()
|
| 19 |
+
|
| 20 |
+
def parse_output(out):
    '''
    Converts raw REPL output text into a Python list.

    The value between the first pair of ">>" markers is extracted with
    clean_carrots and parsed with ast.literal_eval. The parsed value can
    arrive as a dict, tuple, or list and is always returned in list form:
    a dict yields its values, a tuple or list yields its elements.

    Raises:
        ValueError: if no ">> ... >>" delimited value is present, or if
            ast.literal_eval rejects the extracted text.
        SyntaxError: if the extracted text is not valid Python syntax.
        Exception: if the parsed value is not a dict, tuple, or list.
    '''
    literal = clean_carrots(out)
    if literal is None:
        # Previously this fell through to ast.literal_eval(None), which also
        # raised ValueError but with an unhelpful "malformed node" message.
        raise ValueError("No '>> ... >>' delimited value found in REPL output.")

    value = ast.literal_eval(literal)
    if isinstance(value, dict):
        return list(value.values())
    if isinstance(value, (tuple, list)):
        # Bug fix: the list branch used to return the builtin `list` type
        # itself instead of the parsed value.
        return list(value)
    raise Exception("Error executing rasp program.")
|
| 32 |
+
|
| 33 |
+
def _join_stripped_lines(text):
    '''
    Strips every newline-separated line of text and joins the lines, each
    followed by a single space.
    '''
    lines = text.split("\n")
    # A trailing newline produces an empty final piece that is not a line.
    if lines[-1] == "":
        lines.pop()
    return "".join(line.strip() + " " for line in lines)


def run_repl(command):
    '''
    Runs the RASP repl in a separate subprocess.

    Sends command followed by "exit()" to the REPL's stdin, waits for the
    process to finish, and collects everything it wrote.

    Returns a tuple (parsed_output, error_text): parsed_output is the
    result of parse_output applied to the flattened stdout, and error_text
    is the flattened stderr as a string.

    Raises whatever parse_output raises when stdout cannot be parsed.
    '''
    # communicate() writes stdin and drains stdout/stderr concurrently.
    # Waiting for exit before reading the pipes (as done previously) can
    # deadlock once the child fills a pipe buffer. The context manager also
    # closes the pipes and reaps the process.
    with subprocess.Popen([python_exe, 'RASP/RASP_support/REPL.py'],
                          stdin=subprocess.PIPE,
                          stdout=subprocess.PIPE,
                          stderr=subprocess.PIPE,
                          text=True) as process:
        raw_output, raw_error = process.communicate(f'{command}\nexit()\n')

    str_error = _join_stripped_lines(raw_error)
    str_output = parse_output(_join_stripped_lines(raw_output))
    return str_output, str_error
|
| 72 |
+
|
| 73 |
+
if __name__ == "__main__":
    # Manual smoke test: run two sample RASP expressions through the REPL
    # and print the parsed result of each (stderr text is ignored).
    sample_commands = (
        "select(tokens, tokens, ==)(\"hi\");",
        "selector_width(select(tokens, tokens, ==))(\"hi\");",
    )
    for command in sample_commands:
        res, _res_err = run_repl(command)
        print(res)
|