Commit
·
7145fd6
1
Parent(s):
e78b7eb
Add REMEND python module
Browse files- pyproject.toml +30 -0
- remend/__init__.py +0 -0
- remend/bpe.py +64 -0
- remend/bpe_apply.py +25 -0
- remend/change_eqn_format.py +79 -0
- remend/check_generated.py +143 -0
- remend/compile_dataset.py +185 -0
- remend/compile_eqn.sh +44 -0
- remend/convert_generated.py +24 -0
- remend/deduplicate_split.py +111 -0
- remend/disassemble.py +553 -0
- remend/edit_model.py +16 -0
- remend/eval_generated.py +100 -0
- remend/experiment.py +75 -0
- remend/find_duplicates.py +69 -0
- remend/implementation.py +210 -0
- remend/parser.py +449 -0
- remend/plot_loss.py +60 -0
- remend/preprocess_remaqe.py +102 -0
- remend/util.py +21 -0
pyproject.toml
ADDED
|
@@ -0,0 +1,30 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
[build-system]
|
| 2 |
+
requires = ["hatchling"]
|
| 3 |
+
build-backend = "hatchling.build"
|
| 4 |
+
|
| 5 |
+
[project]
|
| 6 |
+
name = "remend"
|
| 7 |
+
version = "1.0"
|
| 8 |
+
authors = [{name="Meet Udeshi", email="m.udeshi@nyu.edu"}]
|
| 9 |
+
description = "Neural Decompilation for Reverse Engineering Math Equations from Binary Executables"
|
| 10 |
+
readme = "README.md"
|
| 11 |
+
classifiers = [
|
| 12 |
+
"Programming Language :: Python :: 3",
|
| 13 |
+
"Operating System :: OS Independent",
|
| 14 |
+
]
|
| 15 |
+
requires-python = ">=3.9"
|
| 16 |
+
dependencies = [
|
| 17 |
+
"networkx",
|
| 18 |
+
"capstone",
|
| 19 |
+
"Levenshtein",
|
| 20 |
+
"tqdm",
|
| 21 |
+
"numpy",
|
| 22 |
+
"sympy",
|
| 23 |
+
"fairseq",
|
| 24 |
+
"torch",
|
| 25 |
+
"matplotlib",
|
| 26 |
+
"tokenizers"
|
| 27 |
+
]
|
| 28 |
+
|
| 29 |
+
[tool.hatch.build.targets.wheel]
|
| 30 |
+
packages = ["remend"]
|
remend/__init__.py
ADDED
|
File without changes
|
remend/bpe.py
ADDED
|
@@ -0,0 +1,64 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from tokenizers import pre_tokenizers, Tokenizer
|
| 2 |
+
from tokenizers.models import BPE
|
| 3 |
+
from tokenizers.trainers import BpeTrainer
|
| 4 |
+
from tokenizers.pre_tokenizers import Whitespace, PreTokenizer
|
| 5 |
+
import random
|
| 6 |
+
import os
|
| 7 |
+
from tqdm import tqdm
|
| 8 |
+
import itertools as it
|
| 9 |
+
|
| 10 |
+
class ImmPreTokenizer:
    """Custom pre-tokenizer that breaks numeric immediates into characters.

    A token that starts with ``0x`` or consists only of digits is split into
    one piece per character; any other token is passed through unchanged.
    """

    def pre_tokenize(self, pretok):
        """Split ``pretok`` in place using :meth:`hex_imm_split`."""
        pretok.split(self.hex_imm_split)

    def hex_imm_split(self, i, norm_str):
        """Return the list of pieces for a single token.

        ``i`` is supplied by the split callback protocol and is not used.
        The pieces are slices of ``norm_str`` itself (not of its string
        form), so they keep the type of the object that was passed in.
        """
        text = str(norm_str)
        is_immediate = text.startswith("0x") or text.isdigit()
        if not is_immediate:
            return [norm_str]
        return [norm_str[pos:pos + 1] for pos in range(len(text))]
|
| 19 |
+
|
| 20 |
+
def get_asm_tok(files, save):
    """Train a BPE tokenizer on ``files``, save it to ``save`` and return it.

    The tokenizer is trained with a Whitespace + ImmPreTokenizer
    pre-tokenizer sequence. For the save step the pre-tokenizer is
    temporarily replaced by plain Whitespace and then restored, so the
    returned object uses the full sequence again.
    """
    def _full_pre_tokenizer():
        # A fresh sequence is built each time, as in the original code.
        return pre_tokenizers.Sequence([Whitespace(), PreTokenizer.custom(ImmPreTokenizer())])

    tokenizer = Tokenizer(BPE(unk_token="@@UNK@@"))
    tokenizer.pre_tokenizer = _full_pre_tokenizer()
    trainer = BpeTrainer(special_tokens=["@@UNK@@"])

    tokenizer.train(files, trainer)

    # Hack to save, careful to restore ImmPreTokenizer.
    # NOTE(review): presumably the custom Python pre-tokenizer cannot be
    # serialized by save() -- confirm against the tokenizers library.
    tokenizer.pre_tokenizer = Whitespace()
    tokenizer.save(save)
    tokenizer.pre_tokenizer = _full_pre_tokenizer()

    return tokenizer
|
| 31 |
+
|
| 32 |
+
def load_asm_tok(load):
    """Load a tokenizer from the file ``load`` and attach the pre-tokenizers.

    The Whitespace + ImmPreTokenizer sequence is installed after loading,
    replacing whatever pre-tokenizer the file specified.
    """
    tokenizer = Tokenizer.from_file(load)
    stages = [Whitespace(), PreTokenizer.custom(ImmPreTokenizer())]
    tokenizer.pre_tokenizer = pre_tokenizers.Sequence(stages)
    return tokenizer
|
| 36 |
+
|
| 37 |
+
|
| 38 |
+
if __name__ == "__main__":
    # Script entry point: train the tokenizer on the train/valid splits found
    # in --indir, then write the tokenized train/valid/test splits to --outdir.
    import argparse

    parser = argparse.ArgumentParser("Train the tokenizer and tokenize the asm")
    # Fixed: the help text for --indir previously said "output directory".
    parser.add_argument("-i", "--indir", required=True, help="input directory")
    parser.add_argument("-o", "--outdir", default="tokenized", help="output directory")
    args = parser.parse_args()

    os.makedirs(args.outdir, exist_ok=True)
    injoin = lambda p: os.path.join(args.indir, p)
    pjoin = lambda p: os.path.join(args.outdir, p)
    max_asm_toks = 0

    # The test split is only encoded below; it is not used for training.
    asm_tok = get_asm_tok([injoin("train.asm"), injoin("valid.asm")], pjoin("asm_tokens.json"))
    for split in ["train", "valid", "test"]:
        asmfile = split + ".asm"
        with open(injoin(asmfile), "r") as asmf, open(pjoin(asmfile), "w") as asmtokf:
            for asm in tqdm(asmf, desc=f"Tokenizing {split}"):
                asm = asm.strip()
                asm_enc = asm_tok.encode(asm)
                # Track the longest encoded sequence over all splits.
                max_asm_toks = max(max_asm_toks, len(asm_enc.tokens))
                asm_seq = " ".join(asm_enc.tokens)
                asmtokf.write(asm_seq + "\n")

    print("Maximum tokens:", max_asm_toks)
|
| 62 |
+
|
| 63 |
+
# After this, run command:
|
| 64 |
+
# fairseq-preprocess -s asm -t eqn --trainpref {OUTDIR}/train --validpref {OUTDIR}/valid --testpref {OUTDIR}/test --destdir {OUTDIR}
|
remend/bpe_apply.py
ADDED
|
@@ -0,0 +1,25 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from tqdm import tqdm
|
| 2 |
+
|
| 3 |
+
from .bpe import load_asm_tok
|
| 4 |
+
|
| 5 |
+
if __name__ == "__main__":
    # Script entry point: encode every line of --input with a previously
    # saved tokenizer and write the space-joined tokens to --output.
    import argparse

    cli = argparse.ArgumentParser("Tokenize using existing tokenizer")
    cli.add_argument("-t", "--tokenizer", required=True, help="existing tokenizer")
    cli.add_argument("-i", "--input", required=True, help="input file")
    cli.add_argument("-o", "--output", required=True, help="output file")
    args = cli.parse_args()

    max_asm_toks = 0
    asm_tok = load_asm_tok(args.tokenizer)

    with open(args.input, "r") as asmf, open(args.output, "w") as asmtokf:
        for line in tqdm(asmf, desc="Tokenizing"):
            encoded = asm_tok.encode(line.strip())
            # Remember the longest token sequence seen so far.
            max_asm_toks = max(max_asm_toks, len(encoded.tokens))
            asmtokf.write(" ".join(encoded.tokens) + "\n")

    print("Maximum tokens:", max_asm_toks)
|
| 25 |
+
|
remend/change_eqn_format.py
ADDED
|
@@ -0,0 +1,79 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from .parser import isint, OPERATORS
|
| 2 |
+
|
| 3 |
+
def prefix_to_brackets(eqn):
    """Convert a prefix token list into a fully bracketed string.

    ``eqn`` is a list of tokens. A token starting with ``INT`` together with
    the integer tokens that follow it is treated as one operand. Each
    operator (a key of ``OPERATORS``) and its operands are wrapped as
    ``( op operand ... )``. The token list must reduce to exactly one
    expression; this is checked with an ``assert``.
    """
    stack = []
    # One (stack index of the operator, operand count) entry per operator
    # that is still waiting for operands.
    pending = []
    total = len(eqn)
    pos = 0
    while pos < total:
        tok = eqn[pos]
        if tok.startswith("INT"):
            # Gather the sign token and its digits into a single operand.
            pieces = [tok]
            pos += 1
            while pos < total and isint(eqn[pos]):
                pieces.append(eqn[pos])
                pos += 1
            stack.append(" ".join(pieces))
            # Step back so the increment at the bottom lands on the first
            # token after the integer.
            pos -= 1
        else:
            if tok in OPERATORS:
                _, arity = OPERATORS[tok]
                pending.append((len(stack), arity))
            stack.append(tok)

        # Fold every operator whose operands are now all on the stack.
        while pending and len(stack) > pending[-1][0] + pending[-1][1]:
            start, _ = pending.pop()
            group = " ".join(stack[start:])
            del stack[start:]
            stack.append(f"( {group} )")
        pos += 1
    assert(len(stack) == 1)
    return stack[0]
|
| 38 |
+
|
| 39 |
+
def prefix_to_postfix(eqn):
    """Convert the first prefix expression in ``eqn`` to postfix.

    Args:
        eqn: non-empty list of prefix tokens.

    Returns:
        A ``(postfix_tokens, remaining_tokens)`` pair, where
        ``remaining_tokens`` is whatever follows the first complete
        expression.

    Raises:
        IndexError: if ``eqn`` is empty, which also happens when an operator
            has fewer operands than ``OPERATORS`` declares.

    An ``INT`` sign token and the integer tokens after it are kept together
    and in their original order. The previous implementation reused the loop
    index after the loop: when the integer was the last thing in ``eqn`` the
    remainder wrongly still contained its final digit, and a lone sign token
    raised ``UnboundLocalError``. Both cases now return an empty remainder.
    """
    head = eqn[0]
    if head.startswith("INT"):
        # ``end`` is the index of the first token that is not part of the
        # integer literal.
        end = 1
        while end < len(eqn) and isint(eqn[end]):
            end += 1
        return eqn[:end], eqn[end:]
    if head in OPERATORS:
        _, numops = OPERATORS[head]
        remeqn = eqn[1:]
        ops = []
        for _ in range(numops):
            operand, remeqn = prefix_to_postfix(remeqn)
            ops.extend(operand)
        ops.append(head)  # Restructured to postfix
        return ops, remeqn
    # Any other token (variable, named constant, ...) is a single operand.
    return [head], eqn[1:]
|
| 58 |
+
|
| 59 |
+
if __name__ == "__main__":
    # Script entry point: rewrite every prefix equation in --eqn as postfix,
    # one equation per line, into --out.
    import argparse

    cli = argparse.ArgumentParser("Change equation format from prefix to other")
    cli.add_argument("--eqn", required=True)
    cli.add_argument("--out", required=True)
    args = cli.parse_args()

    with open(args.eqn, "r") as inf, open(args.out, "w") as outf:
        for line in inf:
            tokens = line.strip().split(" ")
            # Only the converted expression is written; leftover tokens
            # after the first complete expression are discarded.
            postfix = prefix_to_postfix(tokens)[0]
            outf.write(" ".join(postfix) + "\n")
|
| 70 |
+
|
| 71 |
+
# eqn = "div mul x add INT+ 5 add mul INT+ 3 x mul pow x INT+ 2 add INT- 5 add mul INT- 3 x mul x mul add INT+ 1 mul k0 pow x INT+ 3 add INT+ 4 x add INT+ 5 mul INT+ 3 x"
|
| 72 |
+
# eqn = "div add mul INT+ 3 x pow x INT- 4 mul sub x k0 add mul INT+ 5 x k1"
|
| 73 |
+
# print(" ".join(prefix_to_postfix(eqn.split(" "))[0]))
|
| 74 |
+
# postfix = "x INT+ 3 mul INT+ 5 add x INT+ 4 add INT+ 3 x pow k0 mul INT+ 1 add mul x mul x INT- 3 mul add INT- 5 add INT+ 2 x pow mul x INT+ 3 mul add INT+ 5 add x mul div"
|
| 75 |
+
# print(prefix_to_brackets(eqn.split(" ")))
|
| 76 |
+
|
| 77 |
+
# (div (mul x (add INT+ 5 (add (mul INT+ 3 x) (mul (pow x INT+ 2) (add INT- 5 (add (mul INT- 3 x) (mul x (mul (add INT+ 1 (mul k0 (pow x INT+ 3))) (add INT+ 4 x))))))))) (add INT+ 5 (mul INT+ 3 x)))
|
| 78 |
+
# ( div ( mul x ( add INT+ 5 ( add ( mul INT+ 3 x ) ( mul ( pow x INT+ 2 ) ( add INT- 5 ( add ( mul INT- 3 x ) ( mul x ( mul ( add INT+ 1 ( mul k0 ( pow x INT+ 3 ) ) ) ( add INT+ 4 x ) ) ) ) ) ) ) ) ) ( add INT+ 5 ( mul INT+ 3 x ) ) )
|
| 79 |
+
|
remend/check_generated.py
ADDED
|
@@ -0,0 +1,143 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import sympy as sp
|
| 2 |
+
import sys
|
| 3 |
+
import re
|
| 4 |
+
from tqdm import tqdm
|
| 5 |
+
from Levenshtein import distance
|
| 6 |
+
import networkx as nx
|
| 7 |
+
from networkx import graph_edit_distance
|
| 8 |
+
|
| 9 |
+
from .parser import parse_prefix_to_sympy, parse_postfix_to_sympy, isint
|
| 10 |
+
|
| 11 |
+
def percent(a, n):
    """Format ``a`` as a percentage of ``n`` with one decimal place.

    Returns ``"0.0%"`` when ``n`` is zero instead of raising
    ``ZeroDivisionError``, so a summary can still be written for an empty
    result set.
    """
    if n == 0:
        return "0.0%"
    return f"{a/n*100:0.1f}%"
|
| 13 |
+
|
| 14 |
+
def do_simplify_match(orig_expr, gen_expr):
    """Return True when both expressions are equal after ``sp.simplify``.

    The original expression is simplified first, then the generated one;
    the two results are compared with ``==``.
    """
    simplified_orig = sp.simplify(orig_expr)
    simplified_gen = sp.simplify(gen_expr)
    return True if simplified_orig == simplified_gen else False
|
| 20 |
+
|
| 21 |
+
def do_structure_match(orig_toks, gen_toks):
    """Return True when two token lists have the same structure.

    The lists must have equal length. Tokens at the same position match
    when they are identical or when both belong to the same class: both
    start with ``c<digits>``, both start with ``x<digits>``, both satisfy
    ``isint``, or both start with ``INT``.
    """
    const_pattern = r"c[0-9]+"
    var_pattern = r"x[0-9]+"

    def _same_class(left, right):
        # Checks run in the same order as the original ``or`` chain.
        if re.match(const_pattern, left) and re.match(const_pattern, right):
            return True
        if re.match(var_pattern, left) and re.match(var_pattern, right):
            return True
        if isint(left) and isint(right):
            return True
        if left.startswith("INT") and right.startswith("INT"):
            return True
        return left == right

    if len(orig_toks) != len(gen_toks):
        return False
    # all() stops at the first mismatching position.
    return all(_same_class(o, g) for o, g in zip(orig_toks, gen_toks))
|
| 38 |
+
|
| 39 |
+
if __name__ == "__main__":
    # Script entry point: compare generated ("H-") expressions with their
    # targets ("T-") from a generation log and write per-sample results plus
    # summary statistics to the results file.
    import argparse

    parser = argparse.ArgumentParser("Check generated expressions")
    parser.add_argument("-g", required=True, help="Generated expressions file")
    parser.add_argument("-r", required=True, help="Results file")
    parser.add_argument("--simplify", action="store_true", default=False)
    parser.add_argument("--postfix", action="store_true", default=False)
    args = parser.parse_args()

    orig_list = []
    gen_list = []
    with open(args.g, 'r') as f:
        for line in tqdm(f, desc="Reading file"):
            comps = line.strip().split("\t")
            if line[0] == 'T':
                # Target line: id in the first column, tokens in the second.
                num = int(comps[0][2:])
                tokens = comps[1].split(" ")
                orig_list.append((num, tokens))
            elif line[0] == 'H':
                # Hypothesis line: id in the first column, tokens in the third.
                num = int(comps[0][2:])
                tokens = comps[2].split(" ")
                gen_list.append((num, tokens))

    N = len(orig_list)
    if len(gen_list) != N:
        # zip() below stops at the shorter list; make the truncation visible.
        print(f"Warning: {N} target lines but {len(gen_list)} hypothesis lines",
              file=sys.stderr)

    gen_errors = []
    parsed = []
    exact_match = []
    structure_match = []
    simplify_match = []

    orig_exprs = {}
    gen_exprs = {}

    all_aed = []
    # all_ged = []

    results = []

    parse = parse_postfix_to_sympy if args.postfix else parse_prefix_to_sympy

    for (orig_num, orig_toks), (gen_num, gen_toks) in tqdm(zip(orig_list, gen_list), desc="Parsing expressions", total=N):
        assert orig_num == gen_num
        # Token-level edit distance normalised by the combined length.
        aed = distance(orig_toks, gen_toks) / (len(orig_toks) + len(gen_toks))
        all_aed.append(aed)
        res = {"id": gen_num, "aed": aed, "matched": False, "parsed": False}

        if aed == 0:
            # Identical token sequences: no parsing needed.
            parsed.append(orig_num)
            exact_match.append(orig_num)
            structure_match.append(orig_num)
            res["parsed"] = True
            res["matched"] = "Exact"
            results.append(res)
            continue

        if do_structure_match(orig_toks, gen_toks):
            structure_match.append(orig_num)
            res["matched"] = "Structure"

        if "<<unk>>" in orig_toks:
            # Why this happened?
            res["parsed"] = False
            res["matched"] = False
            results.append(res)
            continue

        # A failure to parse the target is not caught and aborts the run.
        orig_expr = parse(orig_toks)
        try:
            gen_expr = parse(gen_toks)
            res["parsed"] = True
        except Exception:
            # Fixed: was a bare ``except:``, which also swallowed
            # KeyboardInterrupt and SystemExit.
            gen_errors.append(gen_num)
            results.append(res)
            continue

        parsed.append(gen_num)
        orig_exprs[gen_num] = orig_expr
        gen_exprs[gen_num] = gen_expr

        if orig_expr == gen_expr:
            exact_match.append(gen_num)
            res["matched"] = "Exact"
        elif args.simplify and do_simplify_match(orig_expr, gen_expr):
            simplify_match.append(gen_num)
            res["matched"] = "Simplify"
        results.append(res)

    with open(args.r, "w") as resf:
        for res in results:
            resf.write("{id} {aed} {parsed} {matched}\n".format(**res))
        resf.write("\n")
        print("Total", N, file=resf)
        print("Parse error", len(gen_errors), percent(len(gen_errors), N), file=resf)
        print("Exact match", len(exact_match), percent(len(exact_match), N), file=resf)
        print("Structure match", len(structure_match), percent(len(structure_match), N), file=resf)
        if args.simplify:
            print("Simplify match", len(simplify_match), percent(len(simplify_match), N), file=resf)
        if all_aed:
            # Fixed: skip the average when no pair was compared instead of
            # dividing by zero / calling max() on an empty list.
            print("Avg SED", sum(all_aed) / len(all_aed), max(all_aed), file=resf)
|
| 142 |
+
# print("Avg GED", sum(all_ged) / len(all_ged), max(all_ged), file=resf)
|
| 143 |
+
|
remend/compile_dataset.py
ADDED
|
@@ -0,0 +1,185 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from tqdm import tqdm
|
| 2 |
+
import random
|
| 3 |
+
import sympy as sp
|
| 4 |
+
import json
|
| 5 |
+
import subprocess as sproc
|
| 6 |
+
from os.path import realpath, dirname, join as pjoin
|
| 7 |
+
from os import makedirs
|
| 8 |
+
import multiprocessing as mp
|
| 9 |
+
from time import sleep
|
| 10 |
+
import logging
|
| 11 |
+
|
| 12 |
+
from .implementation import Implementor
|
| 13 |
+
from .parser import parse_prefix_to_sympy, sympy_to_prefix, constant_fold
|
| 14 |
+
from .disassemble import DisassemblerARM32, DisassemblerAArch64, DisassemblerX64
|
| 15 |
+
from .util import DecodeError, timeout, sympy_expr_ok
|
| 16 |
+
|
| 17 |
+
SCRIPT = pjoin(dirname(realpath(__file__)), "compile_eqn.sh")
|
| 18 |
+
|
| 19 |
+
QUEUE_END = "QUEUE_END_SENTINEL"
|
| 20 |
+
|
| 21 |
+
def compile_c(code, elf, arch="arm32", src="/tmp/myfunc.c", opt=0):
    """Write ``code`` to ``src`` and compile it to ``elf`` via SCRIPT.

    SCRIPT is run through ``bash -e`` with the mode ``<arch>-c`` and the
    flag ``-O<opt>``; its output is captured and not reported.

    Raises:
        DecodeError: if the script exits with a non-zero status.
    """
    with open(src, "w") as srcfile:
        srcfile.write(code)
    command = ["bash", "-e", SCRIPT, arch+"-c", src, elf, f"-O{opt}"]
    result = sproc.run(command, capture_output=True)
    if result.returncode != 0:
        raise DecodeError("compile failed")
|
| 27 |
+
|
| 28 |
+
def compile_fortran(code, elf, arch="arm32", src="/tmp/myfunc.f95", opt=0):
    """Write ``code`` to ``src`` and compile it to ``elf`` via SCRIPT.

    SCRIPT is run through ``bash -e`` with the mode ``<arch>-fortran`` and
    the flag ``-O<opt>``; its output is captured and not reported.

    Raises:
        DecodeError: if the script exits with a non-zero status.
    """
    with open(src, "w") as srcfile:
        srcfile.write(code)
    command = ["bash", "-e", SCRIPT, arch+"-fortran", src, elf, f"-O{opt}"]
    result = sproc.run(command, capture_output=True)
    if result.returncode != 0:
        raise DecodeError("compile failed")
|
| 34 |
+
|
| 35 |
+
class EquationCompiler:
    """Worker that implements, compiles and disassembles queued equations.

    One instance handles a single (implementation, optimization level) pair.
    Work items arrive on ``q`` as ``(n, expr, expr_const, pref)`` tuples and
    the stream is terminated by the ``QUEUE_END`` sentinel.
    """

    def __init__(self, q, arch, impl, opt, outdir, prefix, dtype="double"):
        """Select the compiler and disassembler and store the settings.

        Args:
            q: queue the work items are read from.
            arch: "arm32", "aarch64" or "x64".
            impl: implementation name; names containing "fortran" use the
                Fortran compiler, all others the C compiler.
            opt: optimization level, used for the ``-O`` flag and the
                output directory name.
            outdir: base output directory.
            prefix: file name prefix for the output files.
            dtype: datatype passed on to ``Implementor``.

        Raises:
            DecodeError: if ``arch`` is not one of the supported values.
        """
        if "fortran" in impl:
            self.compiler = compile_fortran
        else:
            self.compiler = compile_c

        if arch == "arm32":
            self.disassembler = DisassemblerARM32
        elif arch == "aarch64":
            self.disassembler = DisassemblerAArch64
        elif arch == "x64":
            self.disassembler = DisassemblerX64
        else:
            raise DecodeError("arch not supported: " + arch)

        self.q = q
        self.impl = impl
        self.opt = opt
        self.outdir = outdir
        self.prefix = prefix
        self.dtype = dtype
        self.arch = arch

    def run(self):
        """Process queued equations until ``QUEUE_END`` is received.

        Results go to ``<outdir>/O<opt>/<impl>/<prefix>.{asm,eqn,src,
        const.jsonl}``; samples that fail are reported in ``<prefix>.error``
        and skipped. The output files are now closed in a ``finally``
        block, so they are also released when an exception other than
        ``DecodeError`` escapes the loop or when one of the opens fails.
        """
        outdir = pjoin(self.outdir, f"O{self.opt}", self.impl)
        makedirs(outdir, exist_ok=True)
        suffixes = {
            "asm": ".asm",
            "eqn": ".eqn",
            "src": ".src",
            "const": ".const.jsonl",
            "err": ".error",
        }

        tmpsrc = f"/tmp/myfunc_{self.impl}_{self.opt}_{self.prefix}"
        if "fortran" in self.impl:
            tmpsrc += ".f95"
            # NOTE(review): presumably the compiled Fortran symbol carries a
            # trailing underscore -- confirm against the toolchain.
            func = "myfunc_"
        else:
            tmpsrc += ".c"
            func = "myfunc"
        tmpelf = f"/tmp/myfunc_{self.arch}_{self.impl}_{self.opt}_{self.prefix}.elf"

        outfiles = {}
        try:
            for key, suffix in suffixes.items():
                outfiles[key] = open(pjoin(outdir, self.prefix + suffix), "w")

            l = 0  # number of samples written so far
            while True:
                data = self.q.get()
                if data == QUEUE_END:
                    # Queue is closed, break from inf loop
                    break
                n, expr, expr_const, pref = data
                impl = Implementor(expr, constants=expr_const, dtype=self.dtype)
                try:
                    code = impl.implement(self.impl)
                    self.compiler(code, tmpelf, arch=self.arch, src=tmpsrc, opt=self.opt)
                    disasm = self.disassembler(tmpelf, expr_constants=expr_const,
                                               match_constants=True)
                    asm = disasm.disassemble(func)
                    if len(disasm.constants) < len(expr_const):
                        # Skip samples where fewer constants were found in
                        # the binary than the expression declares.
                        print(n, "constants not identified", disasm.constants, expr_const,
                              file=outfiles["err"])
                        continue
                except DecodeError as e:
                    print(n, "impl error", e, expr, expr_const, pref, file=outfiles["err"])
                    continue

                outfiles["asm"].write(asm + "\n")
                outfiles["eqn"].write(pref + "\n")
                outfiles["src"].write(f"==== pick={n} line={l} ====\n" + code + "\n")
                outfiles["const"].write(json.dumps(expr_const) + "\n")
                l += 1
        finally:
            for f in outfiles.values():
                f.close()
|
| 108 |
+
|
| 109 |
+
|
| 110 |
+
if __name__ == "__main__":
    # Script entry point: read prefix equations, simplify them, and feed them
    # to one worker process per (implementation, optimization level) pair.
    import argparse

    parser = argparse.ArgumentParser("Compile prefix to asm->eqn dataset")
    parser.add_argument("-f", "--file", required=True, help="Input file")
    parser.add_argument("--outdir", required=True, help="Output directory")
    parser.add_argument("--prefix", required=True, help="File prefix")
    parser.add_argument("--impl", nargs="+", required=True,
                        choices=["dag_c", "cse_c", "dag_fortran", "cse_fortran"])
    parser.add_argument("--pick", type=float, required=True,
                        help="Ratio of samples to pick (0 to 1)")
    parser.add_argument("--start", type=int, default=0, help="Start from index")
    parser.add_argument("--count", type=int, default=0, help="Process only these many")
    parser.add_argument("--seed", type=int, default=1225)
    parser.add_argument("--min-tokens", help="Minimum tokens in equations", type=int, default=5)
    parser.add_argument("--min-ops", help="Minimum ops in equations", type=int, default=5)
    parser.add_argument("--dtype", help="Implementation datatype", type=str,
                        choices=["double", "float"], default="double")
    parser.add_argument("--arch", help="Target architecture", type=str,
                        choices=["arm32", "aarch64", "x64"], default="arm32")
    parser.add_argument("-O", "--opt", nargs="+", type=int, choices=[0, 1, 2, 3], default=[0],
                        help="Optimization level (s)")

    # Dont show warnings
    logging.getLogger("cle").setLevel(logging.ERROR)

    args = parser.parse_args()
    random.seed(args.seed)

    # Each worker gets its own queue; every accepted equation is sent to all.
    eqcompilers = [EquationCompiler(mp.Queue(), args.arch, impl, opt, args.outdir, args.prefix, dtype=args.dtype)
                   for impl in args.impl
                   for opt in args.opt]
    pool = [mp.Process(target=eqc.run, args=()) for eqc in eqcompilers]
    for proc in pool:
        proc.start()

    count = 0
    try:
        # Fixed: the input file is now closed by the ``with`` block.
        with open(args.file, "r") as prefixf:
            for n, line in tqdm(enumerate(prefixf), desc="Parsing file"):
                # Skip for start lines and with some probability
                if n < args.start or random.random() > args.pick:
                    continue
                comps = line.strip().split("\t")
                # The equation is taken from the first column, starting
                # three characters after the position of "Y'".
                pref = comps[0][comps[0].find("Y'")+3:]
                prefl = pref.split(" ")
                # pref = comps[1].split(" ")
                if len(prefl) < args.min_tokens:
                    continue
                try:
                    expr = parse_prefix_to_sympy(prefl)
                    with timeout(10):
                        expr = sp.simplify(expr)
                    if not sympy_expr_ok(expr):
                        # Simplified is bad
                        continue
                    expr, expr_const = constant_fold(expr)
                    pref = " ".join(sympy_to_prefix(expr))
                except Exception:
                    # Fixed: was a bare ``except:``, which also swallowed
                    # KeyboardInterrupt and SystemExit.
                    # NOTE(review): assumes ``timeout`` signals expiry with an
                    # Exception subclass -- confirm in util.py.
                    continue

                if sp.count_ops(expr) < args.min_ops:
                    continue

                for eqc in eqcompilers:
                    # Poll on this queue to get empty
                    while eqc.q.qsize() > 5:
                        sleep(1)
                    eqc.q.put((n, expr, expr_const, pref))
                count += 1
                if args.count > 0 and count >= args.count:
                    break
    finally:
        # Close queues. Done in ``finally`` so the workers are told to stop
        # and are joined even when the producer loop fails or is interrupted.
        for eqc in eqcompilers:
            eqc.q.put(QUEUE_END)
        for proc in pool:
            proc.join()
|
remend/compile_eqn.sh
ADDED
|
@@ -0,0 +1,44 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
#!/bin/bash
# Compile one source file to an ELF with the toolchain selected by MODE.
#
# Usage: compile_eqn.sh MODE SRC ELF [OPT]
#   MODE  one of {arm32,aarch64,x64}-{c,fortran}
#   SRC   path of the source file to compile (must exist)
#   ELF   path of the output binary
#   OPT   optional compiler flag(s), e.g. -O2
#
# Exits with status 1 on bad arguments; otherwise the exit status is that of
# the compiler.

MODE=$1
SRC=$2
ELF=$3
OPT=$4

if [ ! -f "$SRC" ]
then
    echo "Please provide source file" >&2
    exit 1
fi

if [ "$ELF" == "" ]
then
    echo "Please provide elf file path" >&2
    exit 1
fi

# $OPT is deliberately left unquoted so that an empty value expands to
# nothing; $SRC and $ELF are quoted so paths containing spaces work.
case "$MODE" in
    arm32-c)
        arm-linux-gnueabihf-gcc $OPT "$SRC" -lm -o "$ELF"
        ;;
    arm32-fortran)
        arm-linux-gnueabihf-gfortran -std=gnu $OPT "$SRC" -o "$ELF"
        ;;
    aarch64-c)
        aarch64-linux-gnu-gcc $OPT "$SRC" -lm -o "$ELF"
        ;;
    aarch64-fortran)
        aarch64-linux-gnu-gfortran -std=gnu $OPT "$SRC" -o "$ELF"
        ;;
    x64-c)
        gcc $OPT "$SRC" -lm -o "$ELF"
        ;;
    x64-fortran)
        gfortran -std=gnu $OPT "$SRC" -o "$ELF"
        ;;
    *)
        echo "Incorrect mode: $MODE. Choose from: {arm32,aarch64,x64}-{c,fortran}" >&2
        exit 1
        ;;
esac
|
| 43 |
+
|
| 44 |
+
# arm-linux-gnueabihf-objdump --no-show-raw-insn --no-addresses -d $1.elf | sed -n -e 's/\s;\s.*$//' -e "/myfunc>:$/,/^$/p" | sed '1d;$d' | tr '\n' ' '
|
remend/convert_generated.py
ADDED
|
@@ -0,0 +1,24 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from .parser import parse_prefix_to_sympy
|
| 2 |
+
|
| 3 |
+
if __name__ == "__main__":
    # Script entry point: parse every hypothesis ("H") line of a result file
    # and print the resulting expression, or the parse error, per sample.
    import argparse

    cli = argparse.ArgumentParser("Parse result prefix to equation")
    cli.add_argument("--input", required=True, help="Input result file")
    args = cli.parse_args()

    hypotheses = []

    with open(args.input, 'r') as f:
        for line in f:
            if line[0] != 'H':
                continue
            comps = line.strip().split("\t")
            # Sample id comes from the first column, tokens from the third.
            hypotheses.append((int(comps[0][2:]), comps[2].split(" ")))

    for n, toks in hypotheses:
        try:
            ex = parse_prefix_to_sympy(toks)
            print(n, ex)
        except Exception as e:
            print(n, "could not parse:", str(e))
|
remend/deduplicate_split.py
ADDED
|
@@ -0,0 +1,111 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import sys
|
| 2 |
+
import random
|
| 3 |
+
import os
|
| 4 |
+
import re
|
| 5 |
+
from tqdm import tqdm
|
| 6 |
+
|
| 7 |
+
def filter_poly(asm, eqn):
    """Return True when ``asm`` or ``eqn`` contains a rejected token.

    Both strings are stripped and split on single spaces; the sample is
    rejected if any resulting token is one of the function or constant
    names listed below.
    """
    rejects = {"ln", "exp", "sin", "cos", "sqrt", "tan", "asin", "acos", "atan", "E", "pi", "cot"}
    for text in (asm, eqn):
        for token in text.strip().split(" "):
            if token in rejects:
                return True
    return False
|
| 11 |
+
|
| 12 |
+
def filter_bigint(asm, eqn):
    """Return True when ``asm`` contains ``CONST=`` followed by 4+ digits.

    ``eqn`` is accepted for signature parity with the other filters and is
    not inspected.
    """
    found = re.search(r"CONST=[0-9]{4,}", asm)
    return found is not None
|
| 16 |
+
|
| 17 |
+
if __name__ == "__main__":
    # Script entry point: drop repeated / filtered samples, then split the
    # remainder into train, valid and test files inside --outdir.
    import argparse

    parser = argparse.ArgumentParser("Deduplicate ASM and split files into train/test/valid")
    parser.add_argument("--inprefix", required=True, help="Prefix of input files")
    parser.add_argument("--outdir", required=True)
    parser.add_argument("--split", type=float, default=0.05)
    parser.add_argument("--seed", type=int, default=1225)
    parser.add_argument("--filter", choices=["poly", "bigint"], default=None)
    parser.add_argument("--no-separate-eqn", action="store_true")

    args = parser.parse_args()

    eq_mapped = {}    # equation line -> [(input index, asm line, const line)]
    combined_ds = []  # (input index, asm, eqn, const), used with --no-separate-eqn
    asm_hash = set()  # hashes of the asm lines kept so far
    removed = 0

    with open(args.inprefix + ".asm", "r") as asmf, \
            open(args.inprefix + ".eqn", "r") as eqnf, \
            open(args.inprefix + ".const.jsonl", "r") as constf:
        for i, (asm, eqn, const) in tqdm(enumerate(zip(asmf, eqnf, constf)),
                                         desc="Read files", leave=False):
            h = hash(asm)
            if h in asm_hash:
                # Skip this repeated line
                removed += 1
                continue

            if re.search(r"[0-9]\.[0-9]", eqn):
                # Float not represented, remove
                removed += 1
                continue

            if args.filter == "poly" and filter_poly(asm, eqn):
                removed += 1
                continue
            if args.filter == "bigint" and filter_bigint(asm, eqn):
                removed += 1
                continue

            asm_hash.add(h)
            if args.no_separate_eqn:
                combined_ds.append((i, asm, eqn, const))
            else:
                eq_mapped.setdefault(eqn, []).append((i, asm, const))

    print("Removed", removed)

    # Without --no-separate-eqn the split is made over distinct equations,
    # so all samples of one equation end up in the same split.
    if args.no_separate_eqn:
        dataset = combined_ds
    else:
        dataset = list(eq_mapped.keys())

    random.seed(args.seed)
    random.shuffle(dataset)

    N = len(dataset)
    Ntest = int(N * args.split)

    splits = {
        "train": dataset[:N-2*Ntest],
        "valid": dataset[N-2*Ntest:N-Ntest],
        "test": dataset[N-Ntest:]
    }
    splitidxs = {s: [] for s in splits}

    # Fixed: create the output directory instead of failing when it is
    # missing, and close splits.txt through ``with`` even if a write fails.
    os.makedirs(args.outdir, exist_ok=True)
    with open(os.path.join(args.outdir, "splits.txt"), "w") as idxf:
        for s in splits:
            asmfn = os.path.join(args.outdir, f"{s}.asm")
            eqnfn = os.path.join(args.outdir, f"{s}.eqn")
            constfn = os.path.join(args.outdir, f"{s}.const.jsonl")
            with open(asmfn, "w") as asmf, open(eqnfn, "w") as eqnf, \
                    open(constfn, "w") as constf:
                if args.no_separate_eqn:
                    for i, asm, eqn, const in splits[s]:
                        asmf.write(asm)
                        eqnf.write(eqn)
                        constf.write(const)
                        splitidxs[s].append(i)
                else:
                    for eqn in splits[s]:
                        for i, asm, const in eq_mapped[eqn]:
                            asmf.write(asm)
                            eqnf.write(eqn)
                            constf.write(const)
                            splitidxs[s].append(i)
            print("Split", s, len(splitidxs[s]))
            # Record which input line each output line came from.
            idxf.write(f"==== {s} ====\n")
            for j, i in enumerate(splitidxs[s]):
                idxf.write(f"{j}: {i}\n")
            idxf.write("\n")
|
| 111 |
+
|
remend/disassemble.py
ADDED
|
@@ -0,0 +1,553 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from capstone import *
|
| 2 |
+
from capstone.arm import *
|
| 3 |
+
from capstone.arm64 import *
|
| 4 |
+
from capstone.x86 import *
|
| 5 |
+
import cle
|
| 6 |
+
import struct
|
| 7 |
+
from math import e as CONST_E, pi as CONST_PI
|
| 8 |
+
import sympy as sp
|
| 9 |
+
|
| 10 |
+
from .util import DecodeError
|
| 11 |
+
|
| 12 |
+
def int2fp32(v):
    """Reinterpret an integer bit pattern as a little-endian 32-bit float.

    Anything that is not a plain ``int`` is handed back untouched.
    """
    if type(v) is not int:
        return v
    raw = v.to_bytes(4, "little")
    (value,) = struct.unpack("<f", raw)
    return value
|
| 17 |
+
def int2fp64(v):
    """Reinterpret an integer bit pattern as a little-endian 64-bit float.

    Anything that is not a plain ``int`` is handed back untouched.
    """
    if type(v) is not int:
        return v
    raw = v.to_bytes(8, "little")
    (value,) = struct.unpack("<d", raw)
    return value
|
| 22 |
+
|
| 23 |
+
def align4(v):
    """Clear the two lowest bits of ``v``, keeping only the low 32 bits."""
    word_aligned_mask = 0xFFFFFFFF ^ 0x3
    return v & word_aligned_mask
|
| 25 |
+
|
| 26 |
+
class DisassemblerBase:
    """Shared state and constant bookkeeping for the per-architecture disassemblers.

    Child classes are expected to set ``self.loader`` and implement
    ``disassemble``.
    """

    def __init__(self, expr_constants=None, match_constants=False):
        """Initialise empty tracking state.

        expr_constants: mapping of constant name -> numeric value used when
            ``match_constants`` is set. Defaults to a fresh empty dict for
            every instance (never a shared mutable default).
        match_constants: when true, floats must match ``expr_constants`` (or
            be simple constants) instead of receiving fresh ``CSYM`` names.
        """
        self.loader = None # Load in child class
        self.reg_values = {}
        self.constidx = 0
        self.constants = {}
        self.constaddrs = set()
        self.expr_constants = {} if expr_constants is None else expr_constants
        self.match_constants = match_constants

    def get_function_bytes(self, funcname):
        """Return ``(address, bytes)`` of ``funcname`` and record ``self.funcrange``.

        Raises DecodeError when the loader cannot find the symbol.
        """
        func = self.loader.find_symbol(funcname)
        if not func:
            raise DecodeError(f"Function {funcname} not found in binary")
        faddr = func.rebased_addr
        if (not isinstance(self, DisassemblerX64)) and faddr % 2 == 1:
            # Unaligned address, aligning
            faddr = faddr - 1
        fbytes = self.loader.memory.load(faddr, func.size)
        self.funcrange = faddr, faddr + func.size
        return faddr, fbytes

    def find_constant(self, constants, value):
        """Look ``value`` up in ``constants`` within a 1e-5 tolerance.

        The value itself, its reciprocal, its negation and its negated
        reciprocal are tried, in that order, for each entry. Returns
        ``(name, prefix)`` with prefix one of ``""``, ``"1/"``, ``"-"``,
        ``"-1/"``, or ``False`` when nothing matches.
        """
        # A zero value has no reciprocal; skip those comparisons instead of
        # raising ZeroDivisionError.
        has_reciprocal = value != 0
        for ec in constants:
            target = constants[ec]
            if abs(value - target) < 1e-5:
                return ec, ""
            if has_reciprocal and abs(1/value - target) < 1e-5:
                return ec, "1/"
            if abs(-value - target) < 1e-5:
                return ec, "-"
            if has_reciprocal and abs(-1/value - target) < 1e-5:
                return ec, "-1/"
        return False

    def add_constant(self, value, addr=0, size=0):
        """Return the textual token used to represent float ``value``.

        When ``size`` > 0 the byte addresses ``addr .. addr+size-1`` are
        recorded in ``self.constaddrs``.

        Raises DecodeError when ``match_constants`` is set and the value can
        neither be matched nor written as a simple constant.
        """
        # Don't map known constants like e, pi, 0
        if value == 0:
            cname = "CONST=0"
        elif abs(value - CONST_E) < 1e-7:
            cname = "CONST=E"
        elif abs(value - CONST_PI) < 1e-7:
            cname = "CONST=pi"
        elif self.match_constants and \
                (ecmatch := self.find_constant(self.expr_constants, value)):
            # Gives the name and expression of the matched constant
            ecname, ecxpr = ecmatch
            # The first character of the matched name is dropped in the token.
            cname = f"{ecxpr}CSYM{ecname[1:]}"
            self.constants[ecname] = value
        elif size > 0 and addr in self.constaddrs and \
                (smatch := self.find_constant(self.constants, value)):
            # Address already seen: reuse the previously assigned symbol.
            sname, sxpr = smatch
            cname = f"{sxpr}CSYM{sname}"
        else:
            rep = sp.nsimplify(value, [sp.E, sp.pi], tolerance=1e-7)
            if isinstance(rep, sp.Integer) or \
                    (isinstance(rep, sp.Rational) and rep.q <= 16):
                # Integers and small-denominator fractions are emitted inline.
                cname = f"CONST={rep}"
            elif not self.match_constants:
                cname = f"CSYM{self.constidx}"
                self.constants[self.constidx] = value
                self.constidx += 1
            else:
                raise DecodeError(f"Cannot represent unmatched float {value}")

        if size > 0:
            self.constaddrs |= {addr+i for i in range(size)}
        return cname

    def disassemble(self, function):
        """Abstract; implemented by the architecture-specific subclasses."""
        raise NotImplementedError("Call disassemble on child classes, not base")
|
| 97 |
+
|
| 98 |
+
|
| 99 |
+
class DisassemblerARM32(DisassemblerBase):
    """Disassemble ARM32 functions (decoded in Thumb mode) and tag float constants."""

    def __init__(self, binpath, expr_constants=None, match_constants=False):
        """Load ``binpath`` with CLE and set up a Thumb-mode Capstone decoder."""
        # Normalise here so a shared mutable default is never handed around.
        if expr_constants is None:
            expr_constants = {}
        super().__init__(expr_constants=expr_constants, match_constants=match_constants)
        self.md = Cs(CS_ARCH_ARM, CS_MODE_THUMB)
        self.md.detail = True
        self.loader = cle.Loader(binpath)

    def check_mov_imm(self, insn):
        """Return ``(reg, imm)`` for MOV/MOVW/MOVT/ADR of an immediate, else False."""
        if insn.id not in {ARM_INS_MOV, ARM_INS_MOVW,
                           ARM_INS_MOVT, ARM_INS_ADR}:
            return False
        ops = list(insn.operands)
        if len(ops) != 2:
            return False
        if ops[0].type != ARM_OP_REG or ops[1].type != ARM_OP_IMM:
            return False
        imm = ops[1].value.imm
        if imm < 0:
            imm = 2**32 + imm # 2's complement
        if insn.id == ARM_INS_ADR:
            # Add PC value
            imm += insn.address + 4
        return ops[0].value.reg, imm

    def check_float_store(self, insn):
        """Return the float stored by STR/STRD from tracked registers, else False.

        Tracked integer register contents are reinterpreted as float bits;
        results with magnitude below 1e-3 or above 100 are rejected.
        """
        if insn.id not in {ARM_INS_STR, ARM_INS_STRD}:
            return False
        ops = list(insn.operands)
        if insn.id == ARM_INS_STRD:
            dest = ops[0].value.reg
            dest2 = ops[1].value.reg
            if dest not in self.reg_values or dest2 not in self.reg_values:
                return False
            # First register supplies the low word, second the high word.
            fval = int2fp64((self.reg_values[dest2]<<32) + self.reg_values[dest])
        else:
            dest = ops[0].value.reg
            if dest not in self.reg_values:
                return False
            fval = int2fp32(self.reg_values[dest])
        if abs(fval) < 1e-3 or abs(fval) > 100:
            return False
        return fval

    def check_ldrd(self, insn):
        """Return ``(value, addr, 8)`` for an LDRD whose address is known, else False.

        The memory operand is parsed from the operand text; operands that
        cannot be parsed, or addresses outside the loaded image, yield False
        instead of raising.
        """
        if insn.id != ARM_INS_LDRD:
            return False
        ops = insn.op_str.split(", ")
        if len(ops) != 3:
            return False
        mem = ops[2] # format: [<reg> + #<offset>]
        if mem[0] != "[" or mem[-1] != "]":
            return False
        memcomps = mem[1:-1].split(" ")
        basename = memcomps[0]
        if basename == "pc":
            base = align4(insn.address + 4)
        else:
            # Only plain "rN" names can be mapped as ARM_REG_R0 + N; aliases
            # such as sp/lr/ip would make int() raise, so bail out on them.
            # NOTE(review): this still assumes the R0.. register ids are
            # contiguous in Capstone - confirm.
            if not (basename.startswith("r") and basename[1:].isdigit()):
                return False
            basereg = ARM_REG_R0 + int(basename[1:])
            if basereg not in self.reg_values:
                return False
            base = align4(self.reg_values[basereg])
        if len(memcomps) == 3:
            offtext = memcomps[2][1:]
            try:
                # Accept both decimal and 0x-prefixed immediates.
                offset = int(offtext, 16) if "0x" in offtext else int(offtext)
            except ValueError:
                return False
        else:
            offset = 0
        addr = base + offset
        if addr < self.loader.min_addr or addr + 8 > self.loader.max_addr:
            # Out of bounds (same guard as check_vldr)
            return False
        fhex = self.loader.memory.load(addr, 8)
        fval = struct.unpack("d", fhex)[0]
        return fval, addr, 8

    def check_vldr(self, insn):
        """Return ``(value, addr, size)`` for a VLDR from a known address, else False."""
        if insn.id != ARM_INS_VLDR:
            return False
        ops = list(insn.operands)
        dest = ops[0]
        if ops[1].type != ARM_OP_MEM:
            return False
        mem = ops[1].value.mem
        if mem.base == ARM_REG_PC:
            # Align4(PC) + Imm
            # For whatever reason, in Thumb PC=addr+4
            addr = align4(insn.address + 4) + mem.disp
        elif mem.base in self.reg_values:
            addr = align4(self.reg_values[mem.base]) + mem.disp
        else:
            return False
        if addr < self.loader.min_addr or addr + 8 > self.loader.max_addr:
            # Out of bounds
            return False
        if dest.value.reg >= ARM_REG_D0 and dest.value.reg <= ARM_REG_D31:
            # D registers hold doubles
            size = 8
            fhex = self.loader.memory.load(addr, 8)
            fval = struct.unpack("d", fhex)[0]
        else:
            size = 4
            fhex = self.loader.memory.load(addr, 4)
            fval = struct.unpack("f", fhex)[0]
        return fval, addr, size

    def check_vmov(self, insn):
        """Return ``(asm, value)`` for a float-immediate move, else False."""
        # fconsts/d == vmov.f32/f64 (old/new names)
        if insn.id not in {ARM_INS_FCONSTS, ARM_INS_FCONSTD}:
            return False
        ops = list(insn.operands)
        if len(ops) != 2 or ops[1].type != ARM_OP_FP:
            return False
        fval = ops[1].value.fp
        destname = insn.reg_name(ops[0].value.reg)
        asm = f"{insn.mnemonic} {destname}, {fval}"
        return asm, fval

    def check_branch_symbol(self, insn):
        """Return the branch rewritten with a symbolic target, else False.

        Targets strictly inside the current function become ``SELF+<offset>``;
        other targets are resolved through the loader's PLT stub names.
        """
        if insn.id not in {ARM_INS_B, ARM_INS_BL, ARM_INS_BLX}:
            return False
        ops = list(insn.operands)
        if len(ops) != 1 or ops[0].type != ARM_OP_IMM:
            return False
        addr = ops[0].value.imm
        if addr > self.funcrange[0] and addr < self.funcrange[1]:
            # Self-branch
            func = f"SELF+{hex(addr - self.funcrange[0])}"
        else:
            func = self.loader.find_plt_stub_name(addr)
            if func is None:
                # Some tail call optimized PLT stubs have extra instructions
                # that are not identified by CLE, so check with offset of 4 also.
                func = self.loader.find_plt_stub_name(addr + 4)
            if func is None:
                return False
        asm = f"{insn.mnemonic} <{func}>"
        return asm

    def get_function_bytes(self, funcname):
        """Return ``(address, bytes)`` of ``funcname`` and record ``self.funcrange``.

        An odd symbol address is decremented by one before loading.
        Raises DecodeError when the symbol is missing.
        """
        func = self.loader.find_symbol(funcname)
        if not func:
            raise DecodeError(f"Function {funcname} not found in binary")
        faddr = func.rebased_addr
        if faddr % 2 == 1:
            # Unaligned address, aligning
            faddr = faddr - 1
        fbytes = self.loader.memory.load(faddr, func.size)
        self.funcrange = faddr, faddr + func.size
        return faddr, fbytes

    def disassemble(self, funcname):
        """Return the function's instructions as one ``"; "``-joined string.

        Float constants are appended to the instruction that uses them and
        branches get symbolic targets where they can be resolved.
        """
        funcaddr, funcbytes = self.get_function_bytes(funcname)
        disassm = []

        for insn in self.md.disasm(funcbytes, funcaddr):
            if insn.address in self.constaddrs:
                # Skip if this is a constant value and not instruction
                continue

            cname = None
            asm = None

            if vldr := self.check_vldr(insn):
                fval, faddr, fsize = vldr
                cname = self.add_constant(fval, faddr, fsize)
            elif ldrd := self.check_ldrd(insn):
                fval, faddr, fsize = ldrd
                cname = self.add_constant(fval, faddr, fsize)
            elif strfloat := self.check_float_store(insn):
                fval = strfloat
                cname = self.add_constant(fval)
            elif vmovfloat := self.check_vmov(insn):
                asm, fval = vmovfloat
                cname = self.add_constant(fval)
            elif branch := self.check_branch_symbol(insn):
                asm = branch

            # Maintain values of immediate moves.
            # Needs to be done after processing current instruction.
            if movimm := self.check_mov_imm(insn):
                reg, imm = movimm
                if insn.id == ARM_INS_MOVT:
                    # MOVT replaces only the upper half-word and keeps the
                    # lower one, so mask-and-or rather than add.
                    low = self.reg_values.get(reg, 0) & 0xFFFF
                    self.reg_values[reg] = low | ((imm & 0xFFFF) << 16)
                else:
                    self.reg_values[reg] = imm
            else:
                reads, writes = insn.regs_access()
                for r in writes:
                    # Remove this reg if written to
                    if r in self.reg_values:
                        del self.reg_values[r]

            if not asm:
                asm = f"{insn.mnemonic} {insn.op_str}"
            if cname:
                asm += f", {cname}"
            disassm.append(asm)

        fulldiss = "; ".join(disassm)
        return fulldiss
|
| 294 |
+
|
| 295 |
+
class DisassemblerAArch64(DisassemblerBase):
    """Disassemble AArch64 functions and tag float constants."""

    def __init__(self, binpath, expr_constants=None, match_constants=False):
        """Load ``binpath`` with CLE and set up an ARM64 Capstone decoder."""
        # Normalise here so a shared mutable default is never handed around.
        if expr_constants is None:
            expr_constants = {}
        super().__init__(expr_constants=expr_constants, match_constants=match_constants)
        self.md = Cs(CS_ARCH_ARM64, CS_MODE_ARM)
        self.md.detail = True
        self.loader = cle.Loader(binpath)

    def reg_size_type(self, reg):
        """Return ``(bit_width, python_type)`` for a register id, or ``(0, None)``."""
        # Bit width and datatype of register
        if reg >= ARM64_REG_W0 and reg <= ARM64_REG_W30:
            return 32, int
        elif reg >= ARM64_REG_X0 and reg <= ARM64_REG_X30:
            return 64, int
        elif reg >= ARM64_REG_S0 and reg <= ARM64_REG_S31:
            return 32, float
        elif reg >= ARM64_REG_D0 and reg <= ARM64_REG_D31:
            return 64, float
        return 0, None

    def check_mov_imm(self, insn):
        """Return ``(reg, value)`` for ADRP/ADR/MOV/MOVK of an immediate, else False."""
        if insn.id not in {ARM64_INS_ADRP, ARM64_INS_ADR, ARM64_INS_MOV, ARM64_INS_MOVK}:
            return False

        ops = insn.operands
        if len(ops) != 2:
            return False
        if ops[0].type != ARM64_OP_REG or ops[1].type != ARM64_OP_IMM:
            return False

        imm = ops[1].value.imm
        # Without an explicit LSL the immediate occupies bits 0..15. The mask
        # must exist in that case too, otherwise an unshifted MOVK would hit
        # an unbound local below.
        shift = 0
        if ops[1].shift.type == 1: # LSL
            shift = ops[1].shift.value
            imm <<= shift
        mask = 0xFFFF << shift

        if insn.id == ARM64_INS_ADRP:
            # imm -= 0x400000 # Subtract global offset for some reason
            # imm = ((insn.address + 4) & (~4095)) + imm
            # Really confused about this, maybe I can use the imm directly
            pass
        elif insn.id == ARM64_INS_ADR:
            # NOTE(review): hard-coded 0x400000 base adjustment - confirm it
            # holds for every binary this is run on.
            imm -= 0x400000 # Subtract global offset for some reason
            imm += insn.address + 4
        elif insn.id == ARM64_INS_MOVK:
            # load previous reg value
            if ops[0].value.reg in self.reg_values:
                curr = self.reg_values[ops[0].value.reg]
                # Replace only the 16-bit field selected by the shift.
                imm = (imm & mask) | (curr & (~mask))

        return ops[0].value.reg, imm

    def check_fmov(self, insn):
        """Return ``(asm, value)`` for an FMOV whose float value is known, else False.

        Immediate forms use the operand value directly. Register forms
        reinterpret the tracked integer value of the source register and
        reject magnitudes outside [1e-5, 1e5].
        """
        if insn.id != ARM64_INS_FMOV:
            return False
        ops = insn.operands
        if len(ops) != 2: # or ops[1].type != ARM64_OP_FP:
            return False

        destsize, _ = self.reg_size_type(ops[0].value.reg)
        destname = insn.reg_name(ops[0].value.reg)
        if ops[1].type == ARM64_OP_FP:
            fval = ops[1].value.fp
            asm = f"{insn.mnemonic} {destname}, {fval}"
        elif ops[1].type == ARM64_OP_REG:
            reg = ops[1].value.reg
            if reg not in self.reg_values:
                return False
            # TODO datatype
            fhex = self.reg_values[reg]
            if destsize == 64:
                if fhex < 0:
                    fhex += 2**64
                fval = int2fp64(fhex)
            elif destsize == 32:
                if fhex < 0:
                    fhex += 2**32
                fval = int2fp32(fhex)
            else:
                return False

            if abs(fval) < 1e-5 or abs(fval) > 1e5:
                return False
            asm = f"{insn.mnemonic} {insn.op_str}"
        else:
            # Any other operand kind carries no recoverable float value.
            return False
        return asm, fval

    def check_ldr(self, insn):
        """Return ``(value, addr, size)`` for a float LDR from a known address, else False.

        Only the ``[xN]`` and ``[xN, #imm]`` operand forms are handled; other
        addressing forms, unparsable operands and addresses outside the
        loaded image yield False.
        """
        if insn.id != ARM64_INS_LDR:
            return False
        # Drop the trailing "]" before splitting the operand text.
        ops = insn.op_str[:-1].split(", ")
        destsize, desttype = self.reg_size_type(insn.operands[0].value.reg)
        if len(ops) < 2 or desttype != float:
            return False
        reg = ops[1]
        if reg[0] != "[" or "sp" in reg:
            return False
        # Only "[xN" can be mapped as ARM64_REG_X0 + N; writeback forms leave
        # extra characters here that would make int() raise.
        # NOTE(review): assumes the X register ids are contiguous from
        # ARM64_REG_X0 in Capstone - confirm.
        if not (reg.startswith("[x") and reg[2:].isdigit()):
            return False
        basereg = ARM64_REG_X0 + int(reg[2:])
        if basereg not in self.reg_values:
            return False
        base = align4(self.reg_values[basereg])
        if len(ops) == 3:
            if not ops[2].startswith("#"):
                # Register offset: the effective address is not known here.
                return False
            offtext = ops[2][1:]
            try:
                offset = int(offtext, 16) if "0x" in offtext else int(offtext)
            except ValueError:
                return False
        elif len(ops) > 3:
            # Extended/shifted register offsets are not supported.
            return False
        else:
            offset = 0
        addr = base + offset
        if destsize == 64:
            size, fmt = 8, "d"
        elif destsize == 32:
            size, fmt = 4, "f"
        else:
            return False
        if addr < self.loader.min_addr or addr + size > self.loader.max_addr:
            # Out of bounds
            return False
        fhex = self.loader.memory.load(addr, size)
        fval = struct.unpack(fmt, fhex)[0]
        return fval, addr, size

    def check_branch_symbol(self, insn):
        """Return the branch rewritten with a symbolic target, else False.

        Targets strictly inside the current function become ``SELF+<offset>``;
        other targets are resolved through the loader's PLT stub names.
        """
        if insn.id not in {ARM64_INS_BL, ARM64_INS_B}:
            return False
        ops = insn.operands
        # Compare against the ARM64 operand-type constant, not the ARM32 one.
        if len(ops) != 1 or ops[0].type != ARM64_OP_IMM:
            return False
        addr = ops[0].value.imm
        if addr > self.funcrange[0] and addr < self.funcrange[1]:
            # Self-branch
            func = f"SELF+{hex(addr - self.funcrange[0])}"
        else:
            func = self.loader.find_plt_stub_name(addr)
            if func is None:
                # Some tail call optimized PLT stubs have extra instructions
                # that are not identified by CLE, so check with offset of 4 also.
                func = self.loader.find_plt_stub_name(addr + 4)
            if func is None:
                return False
        asm = f"{insn.mnemonic} <{func}>"
        return asm

    def disassemble(self, funcname):
        """Return the function's instructions as one ``"; "``-joined string.

        Float constants are appended to the instruction that uses them and
        branches get symbolic targets where they can be resolved.
        """
        funcaddr, funcbytes = self.get_function_bytes(funcname)
        disassm = []

        for insn in self.md.disasm(funcbytes, funcaddr):
            if insn.address in self.constaddrs:
                # Skip if this is a constant value and not instruction
                continue

            cname = None
            asm = None
            # Maintain values of immediate moves
            if movimm := self.check_mov_imm(insn):
                reg, imm = movimm
                self.reg_values[reg] = imm
            else:
                reads, writes = insn.regs_access()
                for r in writes:
                    # Remove this reg if written to
                    if r in self.reg_values:
                        del self.reg_values[r]

            if fmov := self.check_fmov(insn):
                asm, fval = fmov
                cname = self.add_constant(fval)
            elif ldr := self.check_ldr(insn):
                fval, faddr, fsize = ldr
                cname = self.add_constant(fval, faddr, fsize)
            elif branch := self.check_branch_symbol(insn):
                asm = branch

            if not asm:
                asm = f"{insn.mnemonic} {insn.op_str}"
            if cname:
                asm += f", {cname}"
            disassm.append(asm)

        fulldiss = "; ".join(disassm)
        return fulldiss
|
| 474 |
+
|
| 475 |
+
class DisassemblerX64(DisassemblerBase):
    """Disassemble x86-64 functions and tag RIP-relative float constants."""

    def __init__(self, binpath, expr_constants=None, match_constants=False):
        """Load ``binpath`` with CLE and set up a 64-bit x86 Capstone decoder."""
        # Normalise here so a shared mutable default is never handed around.
        if expr_constants is None:
            expr_constants = {}
        super().__init__(expr_constants=expr_constants, match_constants=match_constants)
        self.md = Cs(CS_ARCH_X86, CS_MODE_64)
        self.md.detail = True
        self.loader = cle.Loader(binpath)

    def check_call_symbol(self, insn):
        """Return ``call <name>`` for a direct call to a known PLT stub, else False."""
        if insn.id != X86_INS_CALL:
            return False
        ops = insn.operands
        # Use the x86 operand-type constant for an x86 operand.
        if len(ops) != 1 or ops[0].type != X86_OP_IMM:
            return False
        addr = ops[0].value.imm
        func = self.loader.find_plt_stub_name(addr)
        if func is None:
            return False
        asm = f"{insn.mnemonic} <{func}>"
        return asm

    def check_fload(self, insn):
        """Return ``(value, addr, size)`` for a single RIP-relative memory operand.

        Only 4-byte (float) and 8-byte (double) operands are decoded; every
        other operand size yields False.
        """
        # Cannot rely on ID because any instruction
        # can access memory.
        ops = insn.operands
        memops = [op for op in ops
                  if (op.type == X86_OP_MEM and
                      op.value.mem.base == X86_REG_RIP)]
        if len(memops) != 1:
            return False
        mem, size = memops[0].value.mem, memops[0].size
        if size not in (4, 8):
            # 1/2-byte operands cannot be unpacked as float/double, and
            # wider ones are not a single scalar.
            return False
        # RIP-relative displacement is taken from the end of the instruction.
        addr = insn.address + insn.size + mem.disp
        fhex = self.loader.memory.load(addr, size)
        fval = struct.unpack("f" if size == 4 else "d", fhex)[0]
        return fval, addr, size

    def disassemble(self, funcname):
        """Return the function's instructions as one ``"; "``-joined string.

        Float constants are appended to the instruction that loads them and
        calls get symbolic targets where they can be resolved.
        """
        funcaddr, funcbytes = self.get_function_bytes(funcname)
        disassm = []

        for insn in self.md.disasm(funcbytes, funcaddr):
            asm = None
            cname = None
            if fload := self.check_fload(insn):
                fval, faddr, fsize = fload
                cname = self.add_constant(fval, faddr, fsize)
            elif call := self.check_call_symbol(insn):
                asm = call

            if not asm:
                asm = f"{insn.mnemonic} {insn.op_str}"
            if cname:
                asm += f", {cname}"
            disassm.append(asm)

        fulldiss = "; ".join(disassm)
        return fulldiss
|
| 534 |
+
|
| 535 |
+
|
| 536 |
+
# Regular
|
| 537 |
+
if __name__ == "__main__":
    import argparse

    # Supported --arch values and the disassembler class each one selects.
    arch_disassemblers = {
        "arm32": DisassemblerARM32,
        "aarch64": DisassemblerAArch64,
        "x64": DisassemblerX64,
    }

    parser = argparse.ArgumentParser("Pre-process assembly to replace constants and dump")
    parser.add_argument("--bin", required=True)
    parser.add_argument("--func", required=True)
    # Restricting the choices makes argparse reject an unknown architecture
    # up front instead of failing later on an unbound disassembler.
    parser.add_argument("--arch", required=True, choices=list(arch_disassemblers))
    args = parser.parse_args()

    D = arch_disassemblers[args.arch](args.bin)
    diss = D.disassemble(args.func)
    print(diss)
    print(D.constants)
|
remend/edit_model.py
ADDED
|
@@ -0,0 +1,16 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Checkpoint entries to drop: eight parameter names for each of the encoder
# layers 0 through 5, listed layer by layer.
removed = [
    f"encoder.layers.{layer}.{param}"
    for layer in range(6)
    for param in (
        "in_proj_weight",
        "in_proj_bias",
        "out_proj_weight",
        "out_proj_bias",
        "fc1_weight",
        "fc1_bias",
        "fc2_weight",
        "fc2_bias",
    )
]
|
| 2 |
+
|
| 3 |
+
if __name__ == "__main__":
    import argparse
    import torch

    cli = argparse.ArgumentParser("Edit the checkpoint to remove extra dict weights")
    cli.add_argument("-c", "--checkpoint", required=True, help="Input checkpoint")
    cli.add_argument("-e", "--edited", required=True, help="Edited checkpoint")
    options = cli.parse_args()

    # NOTE(review): weights_only=False performs a full unpickle; only run
    # this on checkpoints from a trusted source.
    checkpoint = torch.load(options.checkpoint, weights_only=False)
    # Drop every listed entry that is present under the 'model' key.
    stale_keys = [name for name in removed if name in checkpoint['model']]
    for name in stale_keys:
        del checkpoint['model'][name]
    torch.save(checkpoint, options.edited)
|
remend/eval_generated.py
ADDED
|
@@ -0,0 +1,100 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import sympy as sp
|
| 2 |
+
import numpy as np
|
| 3 |
+
import warnings
|
| 4 |
+
from sympy.abc import x
|
| 5 |
+
import sys
|
| 6 |
+
import json
|
| 7 |
+
from tqdm import tqdm
|
| 8 |
+
|
| 9 |
+
from .parser import parse_prefix_to_sympy, isint
|
| 10 |
+
|
| 11 |
+
# Ignore sympy lambda warnings.
|
| 12 |
+
warnings.simplefilter("ignore")
|
| 13 |
+
|
| 14 |
+
def percent(a, n):
    """Format ``a`` out of ``n`` as a percentage with one decimal place.

    An empty total (``n == 0``) is reported as ``"0.0%"`` instead of raising
    ZeroDivisionError.
    """
    if n == 0:
        return "0.0%"
    return f"{a/n*100:0.1f}%"
|
| 16 |
+
|
| 17 |
+
def do_eval_match(orig_expr, gen_expr):
    """Numerically compare two expressions in ``x`` on the grid 0.2, 0.21, ... < 1.

    Points where either expression is not finite are skipped. The
    expressions match when every remaining point agrees to a relative error
    of 1e-5 (absolute error where the reference value is exactly zero) and
    at least five points were compared. Any evaluation error counts as a
    mismatch.
    """
    try:
        origl = sp.lambdify(x, orig_expr)
        genl = sp.lambdify(x, gen_expr)
        count = 0

        for v in np.arange(0.2, 1, 0.01):
            o = origl(v)
            g = genl(v)
            # NaN never compares equal to anything (itself included), so an
            # equality test cannot detect it; check finiteness explicitly.
            # This also covers -inf.
            if not np.isfinite(o) or not np.isfinite(g):
                continue
            # if type(o) != np.float64 or type(g) != np.float64:
            #     print(orig_expr, o, gen_expr, g)
            #     return False
            if o == 0:
                # A relative error is undefined at zero; compare absolutely.
                if abs(g) > 1e-5:
                    return False
            elif abs((o-g)/o) > 1e-5:
                return False
            count += 1
    except Exception:
        # Evaluation failures mean "no match"; KeyboardInterrupt and
        # SystemExit are deliberately left to propagate.
        return False
    return count >= 5
|
| 39 |
+
|
| 40 |
+
if __name__ == "__main__":
    import argparse

    cli = argparse.ArgumentParser("Check generated expressions")
    cli.add_argument("-g", required=True, help="Generated expressions file")
    cli.add_argument("-c", required=True, help="Constants file")
    cli.add_argument("-e", required=True, help="Equations file")
    cli.add_argument("-r", required=True, help="Results file")
    args = cli.parse_args()

    # Each line starting with "H" in the generated file is paired with the
    # next unread line of the equations and constants files.
    gens = []
    with open(args.g, 'r') as genf, open(args.c) as constf, open(args.e) as eqnf:
        for line in tqdm(genf, desc="Reading file"):
            if not line.startswith('H'):
                continue
            fields = line.strip().split("\t")
            sample_id = int(fields[0][2:])
            tokens = fields[2].split(" ")
            eqn_line = next(eqnf)
            const_line = next(constf)
            gens.append((sample_id, tokens, eqn_line.strip(),
                         json.loads(const_line.strip())))

    parsed = []
    matched = []
    results = []

    for n, toks, eqn, const in tqdm(gens, desc="Evaluating expressions"):
        # The record is appended right away and updated in place as the
        # sample passes each stage.
        res = {"id": n, "parsed": False, "matched": False, "orig": "", "gen": ""}
        results.append(res)

        if "<<unk>>" in toks:
            # Not parsed
            continue
        try:
            gen_expr = parse_prefix_to_sympy(toks)
        except Exception:
            # Not parsed
            continue

        res["parsed"] = True
        parsed.append(n)

        substitutions = [(sp.Symbol("k"+c), const[c]) for c in const]
        gen_expr = gen_expr.subs(substitutions)
        orig_expr = sp.parse_expr(eqn, local_dict={"x0":x})
        res["orig"] = str(orig_expr)
        res["gen"] = str(gen_expr)

        if do_eval_match(orig_expr, gen_expr):
            res["matched"] = True
            matched.append(n)

    with open(args.r, "w") as resf:
        for res in results:
            resf.write("{id} {parsed} {matched} \"{orig}\" \"{gen}\"\n".format(**res))
        resf.write("\n")
        N = len(gens)
        print("Total", N, file=resf)
        print("Parsed", len(parsed), percent(len(parsed), N), file=resf)
        print("Matched", len(matched), percent(len(matched), N), file=resf)
|
remend/experiment.py
ADDED
|
@@ -0,0 +1,75 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import sympy as sp
|
| 2 |
+
import random
|
| 3 |
+
|
| 4 |
+
from .parser import parse_prefix_to_sympy
|
| 5 |
+
|
| 6 |
+
def isconst(e):
    """Return True when none of the atoms of ``e`` is a symbol."""
    for atom in e.atoms():
        if atom.is_symbol:
            return False
    return True
|
| 7 |
+
def constfold(expr):
    """Fold symbol-free sub-expressions of ``expr`` into simple constants.

    The expression tree is walked breadth-first from the root.  Every node
    that is a number, or contains no symbols, is evaluated numerically and
    passed through ``sp.nsimplify`` (with ``E`` and ``pi`` as candidate
    constants).  If the result is an integer or a rational whose denominator
    is at most 16 it replaces the node directly; otherwise the node is
    replaced by a fresh symbol ``k<i>`` and its float value is recorded.

    Returns a tuple ``(folded_expr, constmap)`` where ``constmap`` maps each
    introduced symbol name to its float value.
    """
    q = [expr]
    cidx = 0
    subsmap = {}
    constmap = {}

    while q:
        curr_expr = q.pop(0)
        # The constness test must look at the node being visited; the
        # previous code passed an undefined name here and raised NameError.
        if isinstance(curr_expr, sp.Number) or isconst(curr_expr):
            const_expr = curr_expr.evalf()
            rep = sp.nsimplify(const_expr, [sp.E, sp.pi])
            if isinstance(rep, sp.Integer) or \
                    (isinstance(rep, sp.Rational) and rep.q <= 16):
                subsmap[curr_expr] = rep
            else:
                subsmap[curr_expr] = sp.Symbol(f"k{cidx}")
                constmap[f"k{cidx}"] = float(const_expr)
                cidx += 1
        else:
            # Not constant as a whole: keep looking inside its operands.
            q.extend(curr_expr.args)

    return expr.subs(subsmap), constmap
|
| 30 |
+
|
| 31 |
+
def replace_const(expr):
    """Replace every ``sp.Float`` in ``expr`` by a simpler stand-in.

    Floats that ``sp.nsimplify`` turns into an integer, or into a rational
    with denominator at most 16, are replaced by that exact value.  Any other
    float is replaced by a new symbol ``c<i>`` whose float value is recorded.

    Returns ``(new_expr, constmap)`` with ``constmap`` mapping symbol names
    to float values.
    """
    replacements = {}
    values = {}
    next_id = 0

    for node in sp.preorder_traversal(expr):
        if not isinstance(node, sp.Float):
            continue

        simple = sp.nsimplify(node)
        small_rational = isinstance(simple, sp.Rational) and simple.q <= 16
        if isinstance(simple, sp.Integer) or small_rational:
            replacements[node] = simple
            continue

        name = f"c{next_id}"
        replacements[node] = sp.Symbol(name)
        values[name] = float(node)
        next_id += 1

    return expr.subs(replacements), values
|
| 47 |
+
|
| 48 |
+
|
| 49 |
+
if __name__ == "__main__":
    # Ad-hoc experiment: sample lines of a prefix-expression file, fold and
    # replace their constants, and print the result for inspection.
    import argparse
    parser = argparse.ArgumentParser("Random experiments")
    parser.add_argument("-f", required=True)
    parser.add_argument("-p", type=float, default=0.1)
    parser.add_argument("-n", type=int, default=20)
    args = parser.parse_args()

    # Fixed seed so the same lines are sampled on every run.
    random.seed(1225)

    count = 0
    with open(args.f, "r") as f:
        for line in f:
            # Keep each line with probability args.p.
            if random.random() > args.p:
                continue

            prefl = line.strip().split(" ")

            orig = parse_prefix_to_sympy(prefl)
            # constfold returns (expr, constmap); it must be unpacked before
            # the expression is handed to replace_const, which calls .subs().
            expr, folded = constfold(orig)
            expr, consts = replace_const(expr)
            # Report the k* constants from folding together with the c* ones.
            print(orig, expr, {**folded, **consts})

            count += 1
            if count == args.n:
                break
|
remend/find_duplicates.py
ADDED
|
@@ -0,0 +1,69 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import sys
|
| 2 |
+
from tqdm import tqdm
|
| 3 |
+
from Levenshtein import distance
|
| 4 |
+
|
| 5 |
+
if __name__ == "__main__":
    # Report test-set lines that also occur in the training set, and
    # optionally the nearest training line by token edit distance.
    import argparse
    parser = argparse.ArgumentParser("Find duplicates in the dataset ASM")
    parser.add_argument("--train", required=True)
    parser.add_argument("--test", required=True)
    parser.add_argument("--result", required=False)
    parser.add_argument("--distance", action="store_true", default=False)
    args = parser.parse_args()

    train = []
    # Raw line -> index of its last occurrence in the training file.  Keying
    # on the line itself (not hash(line)) means two different lines can never
    # be reported as exact duplicates because of a hash collision.
    train_index = {}
    test = []
    with open(args.train, "r") as tf:
        for idx, line in tqdm(enumerate(tf), desc="Read train", leave=False):
            train_index[line] = idx
            train.append(line.strip().split(" "))
    with open(args.test, "r") as tf:
        for line in tqdm(tf, desc="Read test", leave=False):
            test.append(line)

    # When both paths are the same file, a line must not match itself.
    selfcheck = args.train == args.test
    # Distances are only worth computing when there is a file to record them.
    searchdist = bool(args.result) and args.distance
    rf = open(args.result, "w") if args.result else None

    def reswrite(s):
        if rf:
            rf.write(s)

    exact = 0
    try:
        for i, testline in tqdm(enumerate(test), desc="Test", total=len(test)):
            testl = testline.strip().split(" ")
            j = train_index.get(testline)
            if j is not None and (not selfcheck or j != i):
                # Found exact match
                exact += 1
                reswrite(f"{i} {j} 0 0.0\n")
                continue

            # If not, then search
            if searchdist:
                minavgdist, mindist, minj = 100, 100, -1
                for j, trainl in enumerate(train):
                    if selfcheck and j == i:
                        # Same line of the same file: distance 0 is meaningless.
                        continue
                    lendiff = abs(len(trainl) - len(testl))
                    if lendiff > 10:
                        dist = lendiff * 2  # HACK to speed it up
                    else:
                        dist = distance(trainl, testl)
                    avgdist = dist / (len(trainl) + len(testl))
                    if mindist > dist:
                        minavgdist, mindist, minj = avgdist, dist, j

                reswrite(f"{i} {minj} {mindist} {minavgdist}\n")
    finally:
        # Always flush and release the result file, even on error.
        if rf:
            rf.close()

    print("Exact duplicates:", exact)
|
remend/implementation.py
ADDED
|
@@ -0,0 +1,210 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import sympy as sp
|
| 2 |
+
from sympy.codegen import ast
|
| 3 |
+
import itertools as it
|
| 4 |
+
import networkx as nx
|
| 5 |
+
|
| 6 |
+
from .parser import OPERATORS, sympy_to_dag
|
| 7 |
+
from .util import DecodeError
|
| 8 |
+
|
| 9 |
+
def isnum(s):
    """Return True if ``float(s)`` succeeds, False if it raises ValueError."""
    try:
        float(s)
    except ValueError:
        return False
    return True
|
| 15 |
+
|
| 16 |
+
class Implementor:
    """Generate C or Fortran source that evaluates a sympy expression.

    The generated program defines ``myfunc(x)`` and a ``main`` that reads
    ``x`` from standard input and prints ``myfunc(x)``.  Two strategies are
    offered for each language: walking the expression DAG and emitting one
    temporary per operator node (``dag_*``), or using sympy's common
    sub-expression elimination and code printers (``cse_*``).

    Parameters:
        expr: the sympy expression to implement.
        constants: optional mapping of constant name to numeric value; each
            entry is declared in the generated code.
        dtype: ``"double"`` (default) or any other value, which is treated as
            single precision when choosing format strings and Fortran types.
    """

    def __init__(self, expr, constants=None, dtype="double"):
        self.expr = expr
        # A fresh dict per instance: a shared ``{}`` default would leak
        # entries from one Implementor into every later one.
        self.constants = {} if constants is None else constants
        self.cdtype = dtype
        # scanf/printf conversion matching the C type.
        self.cpf = "lf" if dtype == "double" else "f"
        self.fdtype = "double precision" if dtype == "double" else "real"

    def implement(self, impl):
        """Dispatch on ``impl`` name; returns None for an unknown name."""
        if impl == "dag_c":
            return self.dag_to_c_impl()
        elif impl == "cse_c":
            return self.sympy_cse_c_impl()
        elif impl == "dag_fortran":
            return self.dag_to_fortran_impl()
        elif impl == "cse_fortran":
            return self.sympy_cse_fortran_impl()
        return None

    def op_c_impl(self, f, children):
        """Return the C expression applying operator ``f`` to ``children``.

        ``children`` is a list of already-rendered operand strings.  Raises
        DecodeError when ``f`` is not a known unary operator.
        """
        single = self.cdtype != "double"
        if f == "add":
            return " + ".join(children)
        elif f == "mul":
            return " * ".join(children)
        elif f == "pow":
            assert len(children) == 2
            if single:
                return f"powf({children[0]}, {children[1]})"
            return f"pow({children[0]}, {children[1]})"
        elif f == "ln":
            assert len(children) == 1
            if single:
                return f"logf({children[0]})"
            return f"log({children[0]})"
        elif f in OPERATORS and OPERATORS[f][1] == 1:
            # Remaining unary operators use their own name as the C function,
            # with the `f` suffix for the single-precision variant.
            assert len(children) == 1
            if single:
                return f"{f}f({children[0]})"
            return f"{f}({children[0]})"
        raise DecodeError(f"C impl: operation {f} not handled")

    def op_f_impl(self, f, children):
        """Return the Fortran expression applying ``f`` to ``children``.

        Operands of add/mul/pow are parenthesised individually.  Raises
        DecodeError when ``f`` is not a known unary operator.
        """
        if f == "add":
            j = ")+(".join(children)
            return "(" + j + ")"
        elif f == "mul":
            j = ")*(".join(children)
            return "(" + j + ")"
        elif f == "pow":
            assert len(children) == 2
            return f"({children[0]})**({children[1]})"
        elif f == "ln":
            assert len(children) == 1
            return f"log({children[0]})"
        elif f in OPERATORS and OPERATORS[f][1] == 1:
            assert len(children) == 1
            return f"{f}({children[0]})"
        raise DecodeError(f"F impl: operation {f} not handled")

    def full_c_code(self, body):
        """Wrap ``body`` in a complete C program defining myfunc and main."""
        pre = f"#include <stdio.h>\n#include <math.h>\n{self.cdtype} myfunc({self.cdtype} x) {{"
        post = f"}}\nint main() {{ {self.cdtype} x; scanf(\"%{self.cpf}\", &x); printf(\"%{self.cpf}\", myfunc(x)); }}"
        return f"{pre}\n{body}\n{post}"

    def full_f_code(self, body):
        """Wrap ``body`` in a complete Fortran program defining myfunc."""
        pre = "function myfunc(x) result(y)\nimplicit none\n" + \
              f"{self.fdtype}, intent(in) :: x\n{self.fdtype} :: y, E, pi\n"
        post = "end function myfunc\nprogram main\nimplicit none\n" + \
               f"{self.fdtype} :: x\n{self.fdtype} :: myfunc\n" + \
               "read(*, *) x\nprint *, \"y is:\", myfunc(x)\nend program main"
        return f"{pre}\n{body}\n{post}"

    def dag_to_c_impl(self):
        """Emit C code with one temporary variable per operator DAG node."""
        dag = sympy_to_dag(self.expr, csuf="F" if self.cdtype == "float" else "")
        cstr = ""
        added_pi, added_E = False, False
        for c, value in self.constants.items():
            cstr += f"{self.cdtype} {c} = {value};\n"
        varidx = it.count()
        retname = None
        # Reverse topological order visits operands before their users, and
        # finishes on the root of the expression.
        for node in reversed(list(nx.topological_sort(dag))):
            label = dag.nodes[node]["label"]
            children = [dag.nodes[n]["var"] for n in dag.adj[node]]
            if len(children) == 0:
                if label == "pi":
                    if self.cdtype == "float" and not added_pi:
                        cstr += "const float pi = 3.14159265F;\n"
                        added_pi = True
                    else:
                        label = "M_PI"
                elif label == "E":
                    if self.cdtype == "float" and not added_E:
                        cstr += "const float E = 2.71828183F;\n"
                        added_E = True
                    else:
                        label = "M_E"
                dag.nodes[node]["var"] = label
                # Remember leaves too, so an expression that is a single
                # leaf (e.g. just `x`) still has something to return.
                retname = label
                continue
            varname = f"t{next(varidx)}"
            cexpr = self.op_c_impl(label, children)
            dag.nodes[node]["var"] = varname
            cstr += f"{self.cdtype} {varname} = {cexpr};\n"
            retname = varname
        cstr += f"return {retname};\n"
        return self.full_c_code(cstr)

    def dag_to_fortran_impl(self):
        """Emit Fortran code with one temporary per operator DAG node."""
        csuf = "" if self.fdtype == "real" else "d0"
        dag = sympy_to_dag(self.expr, csuf=csuf)
        # Declarations are collected separately so they can precede the
        # executable statements in the final body.
        varstr = ""
        fstr = "parameter E = 2.71828183\nparameter pi = 3.14159265\n"
        for c, value in self.constants.items():
            varstr += f"{self.fdtype} :: {c}\n"
            fstr += f"parameter {c} = {value}{csuf}\n"
        varidx = it.count()
        retname = None
        for node in reversed(list(nx.topological_sort(dag))):
            label = dag.nodes[node]["label"]
            children = [dag.nodes[n]["var"] for n in dag.adj[node]]
            if len(children) == 0:
                dag.nodes[node]["var"] = label
                # Keeps a single-leaf expression from leaving retname unset.
                retname = label
                continue
            varname = f"t{next(varidx)}"
            fexpr = self.op_f_impl(label, children)
            dag.nodes[node]["var"] = varname
            fstr += f"{varname} = {fexpr}\n"
            retname = varname
            varstr += f"{self.fdtype} :: {varname}\n"
        fstr += f"y = {retname};\n"
        fstr = varstr + "\n" + fstr
        return self.full_f_code(fstr)

    def sympy_cse_c_impl(self):
        """Emit C code using sympy's cse() and ccode() printers."""
        if self.cdtype == "float":
            extraargs = {
                "type_aliases": {ast.real: ast.float32},
                "math_macros": {},
            }
        else:
            extraargs = {}
        cstr = ""
        for c, value in self.constants.items():
            cstr += f"{self.cdtype} {c} = {value};\n"
        xvars, xpr = sp.cse(self.expr)
        for vname, vxpr in xvars:
            code = sp.ccode(vxpr, assign_to=vname.name, **extraargs)
            cstr += f"{self.cdtype} {vname.name}; {code};\n"
        assert len(xpr) == 1
        code = sp.ccode(xpr[0], assign_to="y", **extraargs)
        cstr += f"{self.cdtype} y; {code}; return y;\n"
        return self.full_c_code(cstr)

    def sympy_cse_fortran_impl(self):
        """Emit Fortran code using sympy's cse() and fcode() printers."""
        csuf = "" if self.fdtype == "real" else "d0"
        varstr = ""
        fstr = ""
        for c, value in self.constants.items():
            varstr += f"{self.fdtype} :: {c}\n"
            fstr += f"parameter {c} = {value}{csuf}\n"
        xvars, xpr = sp.cse(self.expr)
        for vname, vxpr in xvars:
            varstr += f"{self.fdtype} :: {vname.name}\n"
            fstr += sp.fcode(vxpr, assign_to=vname.name, standard=95, source_format="free") + "\n"
        assert len(xpr) == 1
        fstr += sp.fcode(xpr[0], assign_to="y", standard=95, source_format="free") + "\n"
        fstr = varstr + "\n" + fstr
        if self.fdtype == "real":
            # Hack to fix sympy generation
            fstr = fstr.replace("d0", "")
        return self.full_f_code(fstr)
|
| 192 |
+
|
| 193 |
+
|
| 194 |
+
|
| 195 |
+
# For testing only
|
| 196 |
+
if __name__ == "__main__":
    # Smoke test: print all four generated implementations of one expression.
    from .parser import parse_prefix_to_sympy, sympy_to_dag

    prefs = "add mul div INT+ 1 INT+ 5 x mul div INT+ 1 INT+ 5 mul x tan pow x INT+ 2".split(" ")
    exp = parse_prefix_to_sympy(prefs)
    impl = Implementor(exp, dtype="float")

    demos = (
        ("DAG C:", impl.dag_to_c_impl),
        ("DAG Fortran:", impl.dag_to_fortran_impl),
        ("CSE C:", impl.sympy_cse_c_impl),
        ("CSE Fortran:", impl.sympy_cse_fortran_impl),
    )
    for title, generate in demos:
        print(title)
        print(generate())
|
remend/parser.py
ADDED
|
@@ -0,0 +1,449 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import sympy as sp
|
| 2 |
+
import networkx as nx
|
| 3 |
+
import itertools as it
|
| 4 |
+
import sys
|
| 5 |
+
|
| 6 |
+
from .util import DecodeError, sympy_expr_ok
|
| 7 |
+
|
| 8 |
+
# Token name -> (callable building the sympy expression, number of operands).
# Disabled tokens: inv, pow2..pow5, abs, sign, sinh, cosh, tanh.
OPERATORS = {
    # Elementary functions
    "add": (lambda a, b: a + b, 2),
    "sub": (lambda a, b: a - b, 2),
    "mul": (lambda a, b: a * b, 2),
    "div": (lambda a, b: a / b, 2),
    "pow": (lambda a, b: a ** b, 2),
    "sqrt": (sp.sqrt, 1),
    "exp": (sp.exp, 1),
    "ln": (sp.ln, 1),
    # Trigonometric Functions
    "sin": (sp.sin, 1),
    "cos": (sp.cos, 1),
    "tan": (sp.tan, 1),
    "cot": (sp.cot, 1),
    "sec": (sp.sec, 1),
    "csc": (sp.csc, 1),
    # Trigonometric Inverses
    "asin": (sp.asin, 1),
    "acos": (sp.acos, 1),
    "atan": (sp.atan, 1),
    "acot": (sp.acot, 1),
    "asec": (sp.asec, 1),
    "acsc": (sp.acsc, 1),
}
|
| 44 |
+
|
| 45 |
+
# Token name -> constant value: E, pi and the single digits "0".."9".
CONSTANTS = {
    "E": sp.E,
    "pi": sp.pi,
    **{str(digit): digit for digit in range(10)},
}
|
| 59 |
+
|
| 60 |
+
# Token name -> sympy symbol: x, x0, x1, c0..c10 and k0..k3.
VARIABLES = {
    name: sp.Symbol(name)
    for name in (
        ["x", "x0", "x1"]
        + [f"c{i}" for i in range(11)]
        + [f"k{i}" for i in range(4)]
    )
}
|
| 84 |
+
|
| 85 |
+
# sympy function/class -> operator token name (hyperbolic functions disabled).
FUNC_TO_OP = {
    sp.Add: "add",
    sp.Mul: "mul",
    sp.Pow: "pow",
    sp.log: "ln",
    sp.sqrt: "sqrt",
    sp.exp: "exp",
    sp.Abs: "abs",
    # Trigonometric functions and their inverses share the sympy name.
    **{
        getattr(sp, name): name
        for name in (
            "sin", "cos", "tan", "cot", "sec", "csc",
            "asin", "acos", "atan", "acot", "asec", "acsc",
        )
    },
}
|
| 115 |
+
|
| 116 |
+
def sympy_func_to_op(f):
    """Return the operator token for sympy function ``f``.

    Raises DecodeError when ``f`` has no entry in FUNC_TO_OP.
    """
    try:
        return FUNC_TO_OP[f]
    except KeyError:
        # The old fallback `return str(f)` sat after the raise and could
        # never run; an unknown function is always an error.
        raise DecodeError(f"Op not found {f}") from None
|
| 122 |
+
|
| 123 |
+
def isint(s):
    """Return True if ``int(s)`` succeeds, False if it raises ValueError."""
    try:
        int(s)
    except ValueError:
        return False
    return True
|
| 129 |
+
|
| 130 |
+
def reverse_iter_prefix(prefs):
    """Yield the tokens of ``prefs`` from last to first, merging numbers.

    Number pieces (integer strings and the characters ``e``, ``+``, ``-``,
    ``.``) are buffered until the ``INT+``/``INT-`` or ``FLOAT+``/``FLOAT-``
    marker that precedes them in the sequence is reached; the marker then
    yields a single signed ``int`` or ``float``.  Every other token is
    yielded unchanged.

    ``int()``/``float()`` raise ValueError if a marker has no valid pieces
    buffered for it.
    """
    currnum = []
    for tok in reversed(prefs):
        if isint(tok) or tok in ["e", "+", "-", "."]:
            # Buffer the token whole.  `currnum += tok` would splice it in
            # character by character, so a multi-character token such as
            # "12" came out reversed ("21") when the buffer was re-joined.
            currnum.append(tok)
        elif tok[:3] == "INT":
            # Pieces were collected back to front; restore reading order.
            parsedint = int("".join(reversed(currnum)))
            yield parsedint if tok[3] == "+" else -parsedint
            currnum = []
        elif tok[:5] == "FLOAT":
            parsedfloat = float("".join(reversed(currnum)))
            yield parsedfloat if tok[5] == "+" else -parsedfloat
            currnum = []
        else:
            yield tok
|
| 158 |
+
|
| 159 |
+
def parse_prefix_to_sympy(prefs):
    """Build a sympy expression from a list of prefix-notation tokens.

    Tokens are consumed right to left with an operand stack.  Brackets are
    ignored.  Raises DecodeError on an unknown token, when the stack does not
    hold exactly one value at the end, or when ``sympy_expr_ok`` rejects the
    resulting expression.
    """
    stack = []
    for token in reverse_iter_prefix(prefs):
        if token in OPERATORS:
            build, arity = OPERATORS[token]
            args = [stack.pop() for _ in range(arity)]
            stack.append(build(*args))
            continue
        if token in CONSTANTS:
            stack.append(CONSTANTS[token])
            continue
        if token in VARIABLES:
            stack.append(VARIABLES[token])
            continue
        if type(token) in (int, float):
            stack.append(token)
            continue
        if token in ("(", ")"):
            # Simply ignore brackets
            continue
        raise DecodeError(f"{token} invalid")

    if len(stack) != 1:
        raise DecodeError(f"Stack not empty, invalid expression: {prefs} || {stack}")

    result = stack.pop()
    if sympy_expr_ok(result):
        return result
    raise DecodeError("Complex or infinite expression")
|
| 186 |
+
|
| 187 |
+
def parse_postfix_to_sympy(prefs):
    """Build a sympy expression by evaluating the tokens left to right.

    The tokens are first run through ``reverse_iter_prefix`` (which merges
    number pieces) and the result is reversed back into input order before
    being evaluated with an operand stack.  Brackets are ignored.  Raises
    DecodeError on an unknown token, when the stack does not hold exactly one
    value at the end, or when ``sympy_expr_ok`` rejects the expression.
    """
    merged = list(reverse_iter_prefix(prefs))
    stack = []
    for token in reversed(merged):
        if token in OPERATORS:
            build, arity = OPERATORS[token]
            args = [stack.pop() for _ in range(arity)]
            stack.append(build(*args))
            continue
        if token in CONSTANTS:
            stack.append(CONSTANTS[token])
            continue
        if token in VARIABLES:
            stack.append(VARIABLES[token])
            continue
        if type(token) in (int, float):
            stack.append(token)
            continue
        if token in ("(", ")"):
            # Simply ignore brackets
            continue
        raise DecodeError(f"{token} invalid")

    if len(stack) != 1:
        raise DecodeError(f"Stack not empty, invalid expression: {prefs} || {stack}")

    result = stack.pop()
    if sympy_expr_ok(result):
        return result
    raise DecodeError("Complex or infinite expression")
|
| 214 |
+
|
| 215 |
+
|
| 216 |
+
def parse_prefix_to_tree(prefs):
    """Build a networkx tree from a list of prefix-notation tokens.

    Each token becomes a node labelled with the token.  For ``pow``, ``sub``
    and ``div`` two extra nodes labelled ``lhs`` and ``rhs`` are inserted
    between the operator and its operands so that operand order is recorded.

    Returns ``(tree, root_node_id)``.  Raises DecodeError on an invalid token
    or when the tokens do not reduce to a single root.
    """
    tree = nx.DiGraph()
    stack = []
    # Ids for the extra lhs/rhs nodes start after the range used by tokens.
    extra_id = len(prefs)
    ordered_ops = {"pow", "sub", "div"}

    for node_id, token in enumerate(reverse_iter_prefix(prefs)):
        tree.add_node(node_id, label=token)
        if token in OPERATORS:
            arity = OPERATORS[token][1]
            operands = [stack.pop() for _ in range(arity)]
            if token in ordered_ops:
                # Ordered children
                lhs_id, rhs_id = extra_id, extra_id + 1
                tree.add_node(lhs_id, label="lhs")
                tree.add_node(rhs_id, label="rhs")
                tree.add_edge(node_id, lhs_id)
                tree.add_edge(node_id, rhs_id)
                tree.add_edge(lhs_id, operands[0])
                tree.add_edge(rhs_id, operands[1])
                extra_id += 2
            else:
                for operand in operands:
                    tree.add_edge(node_id, operand)
        elif not (token in CONSTANTS or token in VARIABLES or type(token) == int):
            raise DecodeError(f"Val {token} invalid")
        stack.append(node_id)

    if len(stack) != 1:
        raise DecodeError(f"Stack not empty, invalid expression: {prefs} || {stack}")

    root = stack.pop()
    return tree, root  # Root node
|
| 247 |
+
|
| 248 |
+
def sympy_to_dag(expression, csuf=""):
    """Convert a sympy expression into a networkx DAG of sub-expressions.

    Identical sub-expressions share one node.  Every node stores its sympy
    object under ``expr`` and a string under ``label``: operator nodes get
    their operator token, leaf nodes get a literal rendering (numbers are
    written as floating-point literals with ``csuf`` appended).

    Raises DecodeError (via ``sympy_func_to_op``) for an unsupported function.
    """
    dag = nx.DiGraph()
    seen = {}
    ids = it.count()

    def _visit(node):
        # Operands first, reusing nodes already created for equal operands.
        operand_ids = [
            seen[arg] if arg in seen else _visit(arg)
            for arg in node.args
        ]
        nid = next(ids)
        dag.add_node(nid, expr=node)
        seen[node] = nid
        for oid in operand_ids:
            dag.add_edge(nid, oid)
        return nid

    def _leaf_label(e):
        # Integer must be tested before Rational, which it also satisfies.
        if isinstance(e, sp.Integer):
            return f"{e}.0{csuf}"
        if isinstance(e, sp.Rational):
            return f"{e.p}.0{csuf}/{e.q}.0{csuf}"
        if isinstance(e, sp.Float):
            return f"{float(e)}{csuf}"
        return str(e)

    _visit(expression)
    for nid in dag.nodes:
        attrs = dag.nodes[nid]
        if len(dag.adj[nid]) > 0:
            attrs["label"] = sympy_func_to_op(attrs["expr"].func)
        else:
            attrs["label"] = _leaf_label(attrs["expr"])

    return dag
|
| 285 |
+
|
| 286 |
+
def sympy_to_prefix(expr):
    """Serialise a sympy expression into a list of prefix-notation tokens.

    Integers become ``INT+``/``INT-`` followed by one token per digit,
    non-integer rationals become ``div p q``, products containing ``x**-1``
    factors become ``div``, and sums containing ``-1 * term`` become ``sub``.
    N-ary operators are emitted as nested binary operators.

    Raises DecodeError for a float that does not simplify to an integer or a
    rational with denominator at most 16, and for unsupported functions.
    """
    tokens = []

    def is_minus_one(value):
        return isinstance(value, sp.Integer) and value == -1

    def emit_chain(op, operands):
        # n operands need n-1 leading copies of a binary operator.
        last = len(operands) - 1
        for pos, operand in enumerate(operands):
            if pos < last:
                tokens.append(op)
            visit(operand)

    def visit(node):
        if isinstance(node, sp.Rational):
            if node.q != 1:
                tokens.append("div")
                visit(node.p)
                visit(node.q)
            else:
                visit(node.p)
            return

        if isinstance(node, (sp.Integer, int)):
            value = int(node)
            if value >= 0:
                tokens.append("INT+")
                tokens.extend(str(value))
            else:
                tokens.append("INT-")
                tokens.extend(str(-value))
            return

        if isinstance(node, sp.Symbol):
            tokens.append(str(node))
            return

        if isinstance(node, sp.Mul):
            numer, denom = [], []
            for factor in node.args:
                if isinstance(factor, sp.Pow) and is_minus_one(factor.args[1]):
                    denom.append(factor.args[0])
                else:
                    numer.append(factor)
            if denom:
                tokens.append("div")
                if not numer:
                    # Bare reciprocal: the numerator is an explicit 1.
                    tokens.append("INT+")
                    tokens.append("1")
            emit_chain("mul", numer)
            emit_chain("mul", denom)
            return

        if isinstance(node, sp.Add):
            plus, minus = [], []
            for term in node.args:
                if isinstance(term, sp.Mul) and len(term.args) == 2:
                    first, second = term.args
                    if is_minus_one(second):
                        minus.append(first)
                        continue
                    if is_minus_one(first):
                        minus.append(second)
                        continue
                plus.append(term)
            if minus:
                tokens.append("sub")
                if not plus:
                    # Pure negation: subtract from an explicit 0.
                    tokens.append("INT+")
                    tokens.append("0")
            emit_chain("add", plus)
            emit_chain("add", minus)
            return

        if isinstance(node, sp.Float):
            rep = sp.nsimplify(node, tolerance=1e-7)
            if isinstance(rep, sp.Integer) or \
                    (isinstance(rep, sp.Rational) and rep.q <= 16):
                visit(rep)
                return
            raise DecodeError(f"Float {node} encountered while generating")

        if node == sp.E or node == sp.pi:
            # Transcendental constants
            tokens.append(str(node))
            return

        op = sympy_func_to_op(node.func)
        operands = node.args
        last = len(operands) - 1
        for pos, operand in enumerate(operands):
            # Insert op repeatedly to maintain binary tree
            if pos == 0 or pos < last:
                tokens.append(op)
            visit(operand)

    visit(expr)
    return tokens
|
| 383 |
+
|
| 384 |
+
def constant_fold(expr):
    """Replace constant sub-expressions of ``expr`` by simple values or symbols.

    The expression tree is walked breadth-first. Every sub-expression that
    contains no symbols is evaluated numerically and then either

    * replaced by its simplified form, when ``nsimplify`` reduces it to an
      integer, a rational with denominator <= 16, ``E`` or ``pi``; or
    * replaced by a named constant symbol ``k<i>``. A value equal (within
      1e-7) to an already recorded constant, its reciprocal, its negation
      or its negative reciprocal reuses that constant's symbol instead of
      allocating a new one.

    Sub-expressions that do contain symbols are not replaced themselves;
    their arguments are queued and examined in turn.

    Returns:
        A ``(folded_expr, constmap)`` tuple, where ``constmap`` maps each
        constant name (``"k0"``, ``"k1"``, ...) to its float value.
    """
    # Local import keeps the module's top-level import block untouched.
    from collections import deque

    tol = 1e-7
    subsmap = {}
    constmap = {}

    def _is_const(e):
        # Constant means: no atom of the expression is a symbol.
        return not any(atom.is_symbol for atom in e.atoms())

    def _reuse_known(val):
        """Express ``val`` through an already recorded constant, or return None."""
        for name, known in constmap.items():
            if abs(val - known) < tol:
                return sp.Symbol(name)
            if abs(1/val - known) < tol:
                return 1/sp.Symbol(name)
            if abs(-val - known) < tol:
                return -sp.Symbol(name)
            if abs(-1/val - known) < tol:
                return -1/sp.Symbol(name)
        return None

    # deque gives O(1) pops from the front; traversal order (and therefore
    # the numbering of the k<i> constants) stays breadth-first.
    queue = deque([expr])
    while queue:
        curr_expr = queue.popleft()
        if not (isinstance(curr_expr, sp.Number) or _is_const(curr_expr)):
            queue.extend(curr_expr.args)
            continue

        const_expr = curr_expr.evalf()
        rep = sp.nsimplify(const_expr, [sp.E, sp.pi], tolerance=tol)
        if isinstance(rep, sp.Integer) or \
                (isinstance(rep, sp.Rational) and rep.q <= 16) or \
                rep == sp.E or rep == sp.pi:
            subsmap[curr_expr] = rep
            continue

        val = float(const_expr)
        replacement = _reuse_known(val)
        if replacement is None:
            # One entry is added per new constant, so the next free index
            # is simply the current size of the map.
            name = f"k{len(constmap)}"
            replacement = sp.Symbol(name)
            constmap[name] = val
        subsmap[curr_expr] = replacement

    return expr.subs(subsmap), constmap
|
| 426 |
+
|
| 427 |
+
|
| 428 |
+
# Manual smoke test: only runs when this module is executed directly
if __name__ == "__main__":
    tokens = "add mul INT- 1 x mul pow ln INT+ 4 INT- 1 add x mul INT- 1 pow x INT+ 5".split(" ")
    simplified = sp.simplify(parse_prefix_to_sympy(tokens))
    print(simplified)
    print(constant_fold(simplified))

    # Earlier experiments, kept disabled for reference:
    #
    # prefs = "mul x mul pow cos INT+ 4 INT- 3 pow ln INT+ 3 INT- 6".split(" ")
    # exp = parse_prefix_to_sympy(prefs)
    # print(exp)
    # dag = sympy_to_dag(exp)
    #
    # exp = sp.parse_expr("(((((x0) + ((x0) - ((-((x0) + (x0))) / ((x0) + (x0))))) * k0) - (-((x0) + (x0)))) / (-((x0) + (x0)))) * ((-((((x0) + ((x0) - ((-((x0) + (x0))) / ((x0) + (x0))))) * k0) - ((x0) + ((x0) - ((-((x0) + (x0))) / ((x0) + (x0))))))) * ((((x0) + ((x0) - ((-((x0) + (x0))) / ((x0) + (x0))))) * k0) - ((x0) + ((x0) - ((-((x0) + (x0))) / ((x0) + (x0)))))))", evaluate=False)
    # # print(sympy_to_prefix(exp))
    #
    # simp = sp.simplify(exp)
    # pre = sympy_to_prefix(simp)
    # print(pre)
    # repars = parse_prefix_to_sympy(pre)
    # print(simp)
    # print(repars)
|
remend/plot_loss.py
ADDED
|
@@ -0,0 +1,60 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from matplotlib import pyplot as plt
|
| 2 |
+
import json
|
| 3 |
+
|
| 4 |
+
if __name__ == "__main__":
    import argparse

    parser = argparse.ArgumentParser("Plot loss for the training log")
    parser.add_argument("-t", "--trainlog", required=True, help="Training log file")
    parser.add_argument("-l", "--loss", help="Loss plot to save (optional)")
    parser.add_argument("--log-scale", default=False, action="store_true", help="Log scale")
    # NOTE: despite its name, ``args.no_plot`` is True when the figure should
    # be shown; passing -P/--no-plot stores False.
    parser.add_argument("-P", "--no-plot", default=True, action="store_false", help="Don't open matplotlib figure")
    args = parser.parse_args()

    train_inner_upd, train_inner_loss = [], []
    train_upd, train_loss = [], []
    val_upd, val_loss = [], []

    def _record(data, loss_key, upd_key, upds, losses):
        """Append one (update, loss) point if present and newer than the last one."""
        if loss_key not in data:
            return
        loss = float(data[loss_key])
        upd = int(data[upd_key])
        # Keep update counts strictly increasing; repeated entries are dropped.
        if not upds or upds[-1] < upd:
            upds.append(upd)
            losses.append(loss)

    with open(args.trainlog, "r") as tl:
        for line in tl:
            # Only lines that look like JSON objects carry statistics
            if not line.startswith("{"):
                continue
            try:
                data = json.loads(line.strip())
            except json.JSONDecodeError:
                # Not valid JSON after all: ignore this line
                continue
            _record(data, "loss", "num_updates", train_inner_upd, train_inner_loss)
            _record(data, "valid_loss", "valid_num_updates", val_upd, val_loss)
            _record(data, "train_loss", "train_num_updates", train_upd, train_loss)

    plt.figure()
    plt.plot(train_upd, train_loss, "r")
    plt.plot(val_upd, val_loss, "b")
    if train_inner_upd:
        plt.plot(train_inner_upd, train_inner_loss, "r", alpha=0.3)
    plt.legend(["train", "valid"])

    # ``default`` avoids a ValueError when the log holds no train or
    # validation entries.
    lowest = min(train_loss + val_loss, default=None)
    if args.log_scale:
        plt.yscale("log")
    elif lowest is not None and lowest < 1:
        plt.ylim((0, 1))
    plt.xlabel("Updates")
    plt.ylabel("Loss")
    if args.loss:
        plt.savefig(args.loss)
    if args.no_plot:
        plt.show()
|
remend/preprocess_remaqe.py
ADDED
|
@@ -0,0 +1,102 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import os
|
| 2 |
+
import json
|
| 3 |
+
from tqdm import tqdm
|
| 4 |
+
import itertools as it
|
| 5 |
+
import sympy as sp
|
| 6 |
+
|
| 7 |
+
from .disassemble import DisassemblerARM32
|
| 8 |
+
from .parser import sympy_to_prefix, isint
|
| 9 |
+
|
| 10 |
+
def match_constants(exprconst, asmconst, constsym, eps=1e-5):
    """Match expression constants against constants found in the assembly.

    For every entry of ``exprconst`` the first entry of ``asmconst`` is
    searched whose value equals (within ``eps``) the expression constant,
    its reciprocal, its negation or its negative reciprocal. The negative
    reciprocal case mirrors the relations ``constant_fold`` recognises.

    Args:
        exprconst: mapping of expression-constant key -> numeric value.
        asmconst: mapping of assembly-constant key -> numeric value.
        constsym: mapping of every key of both dicts -> the symbol object
            standing for that constant (must support ``1/s`` and ``-s``).
        eps: absolute tolerance used for the comparisons.

    Returns:
        ``(mapping, mapped)``: ``mapping`` maps the symbol of each matched
        expression constant to the equivalent expression over the assembly
        constant's symbol; ``mapped`` is the set of matched ``exprconst`` keys.
    """
    def _close(a, b):
        return abs(a - b) <= eps

    mapping = {}
    mapped = set()

    for ec in exprconst:
        ecf = float(exprconst[ec])
        ecsym = constsym[ec]
        # Values that are (almost) zero are skipped; their reciprocal could
        # not be formed below.
        if abs(ecf) < eps:
            continue
        for ac, acf in asmconst.items():
            acsym = constsym[ac]
            if _close(acf, ecf):
                replacement = acsym
            elif _close(acf, 1/ecf):
                replacement = 1/acsym
            elif _close(acf, -ecf):
                replacement = -acsym
            elif _close(acf, -1/ecf):
                replacement = -1/acsym
            else:
                continue
            # First matching assembly constant wins
            mapping[ecsym] = replacement
            mapped.add(ec)
            break
    return mapping, mapped
|
| 37 |
+
|
| 38 |
+
def replace_naming(pref):
    """Rename tokens of a prefix expression to the output naming scheme.

    ``"x0"`` becomes ``"x"``; a token made of ``"c"`` followed by something
    ``isint`` accepts (a constant) gets its leading ``"c"`` replaced by
    ``"k"``; every other token is copied unchanged.

    Args:
        pref: iterable of string tokens.

    Returns:
        A new list with the renamed tokens.
    """
    ret = []
    for p in pref:
        if p == "x0":
            ret.append("x")
        # startswith() instead of p[0] so an empty token cannot raise IndexError
        elif p.startswith("c") and isint(p[1:]):
            # Constant
            ret.append("k" + p[1:])
        else:
            ret.append(p)
    return ret
|
| 49 |
+
|
| 50 |
+
|
| 51 |
+
if __name__ == "__main__":
    import argparse

    parser = argparse.ArgumentParser("Pre-process assembly to replace constants and dump")
    parser.add_argument("--list", required=True)
    parser.add_argument("--prefix", required=True)
    args = parser.parse_args()

    with open(args.list, "r") as f:
        mdllist = list(f)
    opts = ["O0", "O1", "O2", "O3"]

    # Model paths in the list file are resolved relative to the list's directory
    basedir = os.path.dirname(args.list)

    # Context managers guarantee the three outputs are flushed and closed
    # even if a model fails half-way through.
    with open(args.prefix + ".asm", "w") as asmf, \
            open(args.prefix + ".eqn", "w") as eqnf, \
            open(args.prefix + ".const.jsonl", "w") as constf:
        for mdl in tqdm(mdllist):
            mdl = mdl.strip()
            mdlname = os.path.basename(mdl)
            with open(os.path.join(basedir, mdl, "expressions.json")) as f:
                expressions = json.load(f)
            yexpr = expressions["expressions"]["y"]
            exprconsts = {c: float(expressions["constants"][c]) for c in expressions["constants"]}
            # Models with more than 4 expression constants are skipped
            if len(exprconsts) > 4:
                continue
            yexpr = sp.parse_expr(yexpr)
            exprconstsym = {c: sp.Symbol(c) for c in expressions["constants"]}

            # The function name does not depend on the optimisation level
            funcname = f"{mdlname}_run"
            for opt in opts:
                binf = os.path.join(basedir, mdl, opt, "c_bin.elf")
                D = DisassemblerARM32(binf)
                diss = D.disassemble(funcname)
                constants = D.constants
                # Binaries with more than 3 extracted constants are skipped
                if len(constants) > 3:
                    continue

                # Assembly constants get symbols named c<key>; replace_naming()
                # later turns those into k<key>.
                exprconstsym.update({c: sp.Symbol(f"c{c}") for c in constants})
                mapping, mapped = match_constants(exprconsts, constants, exprconstsym)
                # Only keep samples where the number of matched expression
                # constants equals the number of assembly constants
                if len(mapped) != len(constants):
                    continue

                exprsubs = yexpr.subs(mapping)
                exprprefix = replace_naming(sympy_to_prefix(exprsubs))

                # One line per sample in each of the three parallel files
                asmf.write(diss + "\n")
                eqnf.write(" ".join(exprprefix) + "\n")
                constf.write(json.dumps(constants) + "\n")
|
remend/util.py
ADDED
|
@@ -0,0 +1,21 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from contextlib import contextmanager
|
| 2 |
+
import signal
|
| 3 |
+
import sympy as sp
|
| 4 |
+
|
| 5 |
+
def timeout_handler(signum, frame):
    """SIGALRM handler that aborts the interrupted code with ``TimeoutError``."""
    raise TimeoutError("Block timed out")


@contextmanager
def timeout(duration):
    """Context manager that raises ``TimeoutError`` if its body runs too long.

    Installs ``timeout_handler`` for ``SIGALRM`` and arms an alarm for
    ``duration`` seconds. On exit the alarm is cancelled and the handler
    that was installed before entry is put back, so the block leaves no
    process-wide signal state behind.

    Args:
        duration: number of seconds passed to ``signal.alarm``.

    Raises:
        TimeoutError: from within the body, when the alarm fires.
    """
    previous = signal.signal(signal.SIGALRM, timeout_handler)
    try:
        signal.alarm(duration)
        yield
    finally:
        # Cancel the pending alarm first so it cannot fire while the old
        # handler is being restored.
        signal.alarm(0)
        # signal.signal() returns None when the previous handler was not
        # installed from Python; such a handler cannot be re-installed.
        if previous is not None:
            signal.signal(signal.SIGALRM, previous)
|
| 15 |
+
|
| 16 |
+
class DecodeError(Exception):
    """Error raised when an expression cannot be converted.

    For example, the prefix generator raises it for a float it cannot
    express as an integer or small rational.
    """
|
| 18 |
+
|
| 19 |
+
def sympy_expr_ok(expr):
    """Return True when ``expr`` contains no complex, infinite or NaN atom.

    The atoms of the expression are checked for the imaginary unit ``I``,
    ``oo``, ``-oo``, ``zoo`` (complex infinity) and ``nan``.
    """
    atoms = expr.atoms()
    # -oo is an atom of its own, distinct from oo, so it is listed explicitly.
    rejected = (sp.I, sp.oo, -sp.oo, sp.zoo, sp.nan)
    return not any(bad in atoms for bad in rejected)
|