| import sympy as sp |
| import numpy as np |
| import warnings |
| from sympy.abc import x |
| import sys |
| import json |
| from tqdm import tqdm |
|
|
| from remend.tools.parser import parse_prefix_to_sympy |
|
|
# Silence every warning for the whole process. Presumably this hides the
# RuntimeWarnings numpy emits when expressions are evaluated at points
# outside their domain (NaN/inf results are filtered out in do_eval_match).
warnings.simplefilter(action="ignore")
|
|
def percent(a, n):
    """Format ``a / n`` as a percentage string with one decimal place.

    Args:
        a: Numerator (e.g. number of parsed or matched expressions).
        n: Denominator (e.g. total number of expressions).

    Returns:
        A string such as ``"66.7%"``, or ``"n/a"`` when ``n`` is zero
        (e.g. an empty generations file) instead of raising
        ``ZeroDivisionError``.
    """
    if n == 0:
        return "n/a"
    return f"{a/n*100:0.1f}%"
|
|
def _is_finite(value):
    """Return True if *value* is a finite number (not NaN, +inf or -inf).

    Uses ``abs`` so it works for Python/numpy reals, complex numbers and
    numeric SymPy objects alike; NaN compares False against infinity.
    """
    return bool(abs(value) < np.inf)


def do_eval_match(orig_expr, gen_expr, points=None, tol=1e-5, min_points=5):
    """Numerically check whether two SymPy expressions in ``x`` agree.

    Both expressions are lambdified in the module-level symbol ``x`` and
    evaluated at every sample point. Points where either value is not
    finite (NaN, +inf or -inf) are skipped. At each remaining point the
    relative error ``|o - g| / |o|`` must not exceed ``tol``; where the
    original value is exactly zero the relative error is undefined, so the
    absolute error ``|g|`` is compared against ``tol`` instead.

    Args:
        orig_expr: Reference expression.
        gen_expr: Candidate expression.
        points: Iterable of sample values for ``x``. Defaults to
            ``np.arange(0.2, 1, 0.01)``.
        tol: Maximum allowed per-point error (default ``1e-5``).
        min_points: Minimum number of comparable (finite) points needed to
            declare a match (default 5).

    Returns:
        True if no comparable point disagrees and at least ``min_points``
        points were comparable; False otherwise, including when
        lambdifying or evaluating either expression raises.
    """
    if points is None:
        points = np.arange(0.2, 1, 0.01)
    try:
        orig_fn = sp.lambdify(x, orig_expr)
        gen_fn = sp.lambdify(x, gen_expr)
        count = 0
        for v in points:
            o = orig_fn(v)
            g = gen_fn(v)
            # Skip points outside either expression's numeric domain.
            # (Previously -inf slipped through and counted as a match,
            # because (-inf - g) / -inf is NaN and NaN > tol is False.)
            if not (_is_finite(o) and _is_finite(g)):
                continue
            err = abs(g) if o == 0 else abs((o - g) / o)
            if err > tol:
                return False
            count += 1
    except Exception:
        # Deliberate best effort: expressions that cannot be lambdified or
        # evaluated (unknown functions, leftover free symbols, ...) simply
        # do not match.
        return False
    return count >= min_points
|
|
def _read_generations(gen_path, info_path):
    """Pair each hypothesis ("H") line of *gen_path* with the next info line.

    Every line of *gen_path* starting with ``H`` consumes exactly one line of
    *info_path*, which is parsed as JSON. Entries whose info has an empty
    ``"eqn"`` are dropped (their info line is still consumed, keeping the two
    files aligned).

    Returns:
        A list of ``(id, tokens, info)`` tuples.

    Raises:
        ValueError: if *info_path* has fewer lines than there are hypotheses.
    """
    gens = []
    with open(gen_path, "r", encoding="utf-8") as genf, \
            open(info_path, encoding="utf-8") as infof:
        for line in tqdm(genf, desc="Reading file"):
            # startswith() instead of line[0]: an empty line must not raise.
            if not line.startswith("H"):
                continue
            # Strip only the line ending: strip() would also remove the
            # trailing tab of an empty hypothesis and break comps[2].
            comps = line.rstrip("\r\n").split("\t")
            # Drop the two-character "H-" prefix (presumably "H-<id>").
            num = int(comps[0][2:])
            tokens = (comps[2] if len(comps) > 2 else "").split(" ")
            info_line = next(infof, None)
            if info_line is None:
                raise ValueError(
                    f"info file {info_path!r} ran out of entries "
                    f"at hypothesis H-{num}"
                )
            info = json.loads(info_line.strip())
            if info["eqn"] == "":
                continue
            gens.append((num, tokens, info))
    return gens


def _evaluate(gens):
    """Parse every generated expression and numerically compare it to the original.

    Returns:
        ``(results, parsed, matched)``: ``results`` holds one dict per entry
        with keys ``id``, ``parsed``, ``matched``, ``orig`` and ``gen``;
        ``parsed`` and ``matched`` are lists of entry ids.
    """
    parsed, matched, results = [], [], []
    for n, toks, info in tqdm(gens, desc="Evaluating expressions"):
        res = {"id": n, "parsed": False, "matched": False, "orig": "", "gen": ""}
        # The dict is filled in below; appending it now records every entry,
        # including those that fail to parse.
        results.append(res)
        if "<<unk>>" in toks:
            continue
        try:
            gen_expr = parse_prefix_to_sympy(toks)
        except Exception:
            continue

        res["parsed"] = True
        parsed.append(n)

        # Replace each constant placeholder symbol k<name> with its value.
        const = info["constants"]
        gen_expr = gen_expr.subs(
            [(sp.Symbol("k" + c), value) for c, value in const.items()]
        )
        # NOTE(review): sympy.parse_expr evaluates its input with eval();
        # only run this script on trusted info files.
        orig_expr = sp.parse_expr(info["eqn"], local_dict={"x0": x})
        res["orig"] = str(orig_expr)
        res["gen"] = str(gen_expr)

        if do_eval_match(orig_expr, gen_expr):
            res["matched"] = True
            matched.append(n)
    return results, parsed, matched


def _write_results(path, results, total, n_parsed, n_matched):
    """Write one line per result followed by Total/Parsed/Matched summary lines."""
    with open(path, "w", encoding="utf-8") as resf:
        for res in results:
            resf.write("{id} {parsed} {matched} \"{orig}\" \"{gen}\"\n".format(**res))
        resf.write("\n")
        print("Total", total, file=resf)
        print("Parsed", n_parsed, percent(n_parsed, total), file=resf)
        print("Matched", n_matched, percent(n_matched, total), file=resf)


def main():
    """Command-line entry point: check generated expressions against originals."""
    import argparse
    parser = argparse.ArgumentParser("Check generated expressions")
    parser.add_argument("-g", required=True, help="Generated expressions file")
    parser.add_argument("-i", required=True, help="Info file")
    parser.add_argument("-r", required=True, help="Results file")
    args = parser.parse_args()

    gens = _read_generations(args.g, args.i)
    results, parsed, matched = _evaluate(gens)
    _write_results(args.r, results, len(gens), len(parsed), len(matched))


if __name__ == "__main__":
    main()
|
|