File size: 1,444 Bytes
c4ac745
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
import argparse
import os

try:
    from evaluator import RelationalEvaluator, format_by_rank
except ImportError as import_error:  # also covers ModuleNotFoundError (a subclass)
    # Fallback for when the directory holding evaluator.py is not on sys.path:
    # load ../../evaluator.py (relative to this file) directly by path and
    # register it as the "evaluator" module.
    import importlib.util  # plain `import importlib` does not guarantee `.util` is loaded
    import sys
    base_dir = os.path.dirname(__file__)
    full_path = os.path.abspath(os.path.join(base_dir, "..", "..", "evaluator.py"))
    if not os.path.isfile(full_path):
        raise ImportError(
            f"evaluator module not importable and not found at {full_path}"
        ) from import_error
    spec = importlib.util.spec_from_file_location("evaluator", full_path)
    if spec is None or spec.loader is None:
        raise ImportError(f"cannot build an import spec for {full_path}") from import_error
    evaluator = importlib.util.module_from_spec(spec)
    # Register before executing so imports of "evaluator" made while it runs resolve to it.
    sys.modules["evaluator"] = evaluator
    spec.loader.exec_module(evaluator)
    RelationalEvaluator = evaluator.RelationalEvaluator
    format_by_rank = evaluator.format_by_rank


class SMMEvaluator(RelationalEvaluator):
    """Relational evaluator configured with the identifier ``"smm"``.

    Evaluates the models keyed ``"ind"``, ``"clava"`` and ``"irg"``. The
    ``hue_order`` and ``renames`` class attributes are presumably consumed by
    ``RelationalEvaluator`` for labelling/ordering results (base class not
    visible here — confirm).
    """

    # Ordering of display labels, with the "Real" series first.
    hue_order = [
        "Real",
        "IRG",
        "IND",
        "CLD",
    ]
    # Internal model key -> display label. Mappings for "sdv" -> "HMA" and
    # "rctgan" -> "RCTGAN" were previously listed here but are disabled.
    renames = dict(
        ind="IND",
        clava="CLD",
        irg="IRG",
    )

    def __init__(self):
        evaluated_models = ["ind", "clava", "irg"]
        super().__init__("smm", models=evaluated_models)


def parse_args(argv: "list[str] | None" = None) -> argparse.Namespace:
    """Parse the command-line flags for this evaluation script.

    Both evaluation stages are enabled by default; each ``--skip-*`` flag
    disables one of them.

    Args:
        argv: Arguments to parse. ``None`` (the default) makes argparse read
            ``sys.argv[1:]``, which matches the original zero-argument behaviour.

    Returns:
        A namespace with boolean attributes ``shape`` and ``schema``.
    """
    parser = argparse.ArgumentParser(
        description="Run the SMM relational evaluation stages."
    )
    # store_false: the attribute is True unless the flag is given.
    parser.add_argument(
        "--skip-shapes", dest="shape", action="store_false", default=True,
        help="skip the shape evaluation stage",
    )
    parser.add_argument(
        "--skip-schema", dest="schema", action="store_false", default=True,
        help="skip the schema evaluation stage",
    )
    return parser.parse_args(argv)


def main():
    """Entry point: parse CLI flags and run each enabled evaluation stage.

    Stages run in order (shapes, then schema); each is skipped when its
    corresponding ``--skip-*`` flag was passed.
    """
    options = parse_args()
    smm = SMMEvaluator()
    stages = (
        (options.shape, "evaluate_shapes"),
        (options.schema, "evaluate_schema"),
    )
    for enabled, method_name in stages:
        if not enabled:
            continue
        # Look the method up only when the stage actually runs.
        getattr(smm, method_name)()


if __name__ == "__main__":
    # Run the evaluation only when executed as a script, not when imported.
    main()