File size: 3,234 Bytes
7145fd6
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
import sympy as sp
import numpy as np
import warnings
from sympy.abc import x
import sys
import json
from tqdm import tqdm

from .parser import parse_prefix_to_sympy, isint

# Silence every warning for the whole process, not only those coming from
# sympy's lambdify: this filter has no category/module restriction and is
# installed at import time, so it also affects any code that imports this
# module. Presumably it targets numpy RuntimeWarnings (divide-by-zero /
# invalid value) raised when do_eval_match evaluates lambdified expressions
# at sample points -- TODO confirm, and consider narrowing the category or
# scoping it with warnings.catch_warnings().
warnings.simplefilter("ignore")

def percent(a, n):
    """Format ``a`` as a percentage of ``n`` with one decimal place.

    Example: ``percent(1, 2) == "50.0%"``.

    Returns ``"n/a"`` when ``n`` is zero (e.g. the generations file contained
    no hypothesis lines) instead of raising ZeroDivisionError.
    """
    if n == 0:
        return "n/a"
    return f"{a / n * 100:0.1f}%"

def do_eval_match(orig_expr, gen_expr, *, rtol=1e-5, min_points=5):
    """Numerically check whether two sympy expressions in ``x`` agree.

    Both expressions are lambdified in the variable ``x`` and evaluated at the
    sample points ``np.arange(0.2, 1, 0.01)``. Sample points where either value
    is NaN or +/-inf are skipped. Every remaining point must satisfy
    ``|orig - gen| <= rtol * |orig|``, or ``|gen| <= rtol`` when the reference
    value is exactly zero.

    Args:
        orig_expr: Reference sympy expression.
        gen_expr: Candidate sympy expression.
        rtol: Relative tolerance per sample point (used as an absolute
            tolerance where the reference value is zero).
        min_points: Minimum number of finite sample points that must have
            been compared for the result to count as a match.

    Returns:
        True if no compared point disagrees and at least ``min_points``
        points were compared; False otherwise, including when lambdifying or
        evaluating either expression raises.
    """
    try:
        orig_fn = sp.lambdify(x, orig_expr)
        gen_fn = sp.lambdify(x, gen_expr)
        count = 0

        for v in np.arange(0.2, 1, 0.01):
            o = orig_fn(v)
            g = gen_fn(v)
            # NaN compares unequal to everything (``o == float('nan')`` is
            # always False), so test finiteness explicitly. This also skips
            # -inf, not just +inf.
            if not (np.isfinite(o) and np.isfinite(g)):
                continue
            # Dividing by ``o`` breaks when the reference value is exactly
            # zero; compare against a scaled limit instead, falling back to an
            # absolute tolerance at zero.
            limit = rtol * abs(o) if o != 0 else rtol
            if abs(o - g) > limit:
                return False
            count += 1
    except Exception:
        # Any failure to build or evaluate the functions counts as "no match".
        # Catching Exception (not a bare except) lets KeyboardInterrupt and
        # SystemExit propagate.
        return False
    return count >= min_points

def _read_generations(gen_path, const_path, eqn_path):
    """Collect hypothesis lines paired with their reference equation and constants.

    Only lines of the generations file that start with ``H`` are used. For
    each one, the next line of the equations file and the next line of the
    constants file (a JSON object) are consumed, so the three files must be
    aligned. Assumes hypothesis lines look like ``H?<id>\\t...\\t<tokens>``
    with space-separated tokens in the third tab field -- TODO confirm
    against the generator's output format.

    Returns:
        A list of ``(id, tokens, equation, constants)`` tuples.

    Raises:
        ValueError: if the equations or constants file runs out before the
            hypothesis lines do.
    """
    gens = []
    with open(gen_path, "r", encoding="utf-8") as genf, \
            open(const_path, encoding="utf-8") as constf, \
            open(eqn_path, encoding="utf-8") as eqnf:
        for line in tqdm(genf, desc="Reading file"):
            if not line.startswith("H"):
                continue
            comps = line.strip().split("\t")
            num = int(comps[0][2:])
            tokens = comps[2].split(" ")
            try:
                eqn = next(eqnf)
                const = next(constf)
            except StopIteration:
                raise ValueError(
                    "equations/constants file has fewer lines than there are "
                    "hypothesis lines in the generations file"
                ) from None
            gens.append((num, tokens, eqn.strip(), json.loads(const.strip())))
    return gens


def _evaluate(gens):
    """Parse each generated expression and compare it with its reference.

    Returns:
        ``(parsed_ids, matched_ids, results)``, where ``results`` holds one
        dict per entry of ``gens`` (same order) with keys ``id``, ``parsed``,
        ``matched``, ``orig`` and ``gen``.
    """
    parsed = []
    matched = []
    results = []

    for n, toks, eqn, const in tqdm(gens, desc="Evaluating expressions"):
        res = {"id": n, "parsed": False, "matched": False, "orig": "", "gen": ""}
        # Appended up front; the dict is filled in place as checks succeed.
        results.append(res)

        if "<<unk>>" in toks:
            # Not parsed
            continue
        try:
            gen_expr = parse_prefix_to_sympy(toks)
        except Exception:
            # Not parsed
            continue

        res["parsed"] = True
        parsed.append(n)

        # Substitute the fitted constants: constant ``c`` is the symbol ``k<c>``.
        gen_expr = gen_expr.subs([(sp.Symbol("k" + c), val) for c, val in const.items()])
        res["gen"] = str(gen_expr)

        try:
            # NOTE(review): sympy.parse_expr evaluates its input with eval();
            # only run this on trusted equation files.
            orig_expr = sp.parse_expr(eqn, local_dict={"x0": x})
        except Exception:
            # One malformed reference equation should not abort the whole run;
            # record it as unmatched and report it.
            tqdm.write(f"warning: could not parse reference equation for id {n}: {eqn!r}",
                       file=sys.stderr)
            continue
        res["orig"] = str(orig_expr)

        if do_eval_match(orig_expr, gen_expr):
            res["matched"] = True
            matched.append(n)

    return parsed, matched, results


def _write_results(path, total, parsed, matched, results):
    """Write one line per result followed by Total/Parsed/Matched summary lines."""
    with open(path, "w", encoding="utf-8") as resf:
        for res in results:
            resf.write("{id} {parsed} {matched} \"{orig}\" \"{gen}\"\n".format(**res))
        resf.write("\n")
        print("Total", total, file=resf)
        print("Parsed", len(parsed), percent(len(parsed), total), file=resf)
        print("Matched", len(matched), percent(len(matched), total), file=resf)


def main():
    """Command-line entry point: read, evaluate and report generated expressions."""
    import argparse
    arg_parser = argparse.ArgumentParser("Check generated expressions")
    arg_parser.add_argument("-g", required=True, help="Generated expressions file")
    arg_parser.add_argument("-c", required=True, help="Constants file")
    arg_parser.add_argument("-e", required=True, help="Equations file")
    arg_parser.add_argument("-r", required=True, help="Results file")
    args = arg_parser.parse_args()

    gens = _read_generations(args.g, args.c, args.e)
    parsed, matched, results = _evaluate(gens)
    _write_results(args.r, len(gens), parsed, matched, results)


if __name__ == "__main__":
    main()