aaljabari committed on
Commit
c3d885e
·
verified ·
1 Parent(s): 45aa939

Create bin/eval.py

Browse files
Files changed (1) hide show
  1. Nested/bin/eval.py +87 -0
Nested/bin/eval.py ADDED
@@ -0,0 +1,87 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ import logging
3
+ import argparse
4
+ from collections import namedtuple
5
+ from Nested.utils.helpers import load_checkpoint, make_output_dirs, logging_config
6
+ from Nested.utils.data import get_dataloaders, parse_conll_files
7
+ from Nested.utils.metrics import compute_single_label_metrics, compute_nested_metrics
8
+
9
+ logger = logging.getLogger(__name__)
10
+
11
+
12
def parse_args(argv=None):
    """Parse the command-line options of the evaluation script.

    Args:
        argv: Optional list of argument strings to parse. When ``None``
            (the default) argparse reads ``sys.argv[1:]``, which is what
            the script entry point relies on. Passing an explicit list
            lets the parser be driven programmatically.

    Returns:
        argparse.Namespace with the attributes ``output_path`` (str),
        ``model_path`` (str), ``data_paths`` (list of str, at least one)
        and ``batch_size`` (int, default 32).

    Raises:
        SystemExit: Raised by argparse when a required option is missing
            or a value cannot be converted.
    """
    parser = argparse.ArgumentParser(
        formatter_class=argparse.ArgumentDefaultsHelpFormatter,
    )

    parser.add_argument(
        "--output_path",
        type=str,
        required=True,
        help="Path to save results",
    )

    parser.add_argument(
        "--model_path",
        type=str,
        required=True,
        help="Model path",
    )

    # nargs="+" collects one or more paths into a list; each path is later
    # paired with its own dataloader and predictions file.
    parser.add_argument(
        "--data_paths",
        nargs="+",
        type=str,
        required=True,
        help="One or more data files to evaluate, in the same format as the training data",
    )

    parser.add_argument(
        "--batch_size",
        type=int,
        default=32,
        help="Batch size",
    )

    return parser.parse_args(argv)
49
+
50
+
51
def main(args):
    """Run a saved tagger over each evaluation file and score its output.

    Args:
        args: Parsed command-line namespace providing ``output_path``,
            ``model_path``, ``data_paths`` and ``batch_size``.

    Side effects:
        Calls ``make_output_dirs`` on ``args.output_path`` with
        ``overwrite=True``, points logging at ``eval.log`` inside that
        directory, and has the tagger write one
        ``predictions_<input file basename>`` file per input file.
    """
    # Output directory and log file come first so later steps can log.
    make_output_dirs(args.output_path, overwrite=True)
    log_path = os.path.join(args.output_path, "eval.log")
    logging_config(log_file=log_path)

    # Restore the trained tagger, its tag vocabulary and training config.
    tagger, tag_vocab, train_config = load_checkpoint(args.model_path)

    # Parse the evaluation files into datasets plus a vocab built from them.
    datasets, data_vocab = parse_conll_files(args.data_paths)

    # Pair the tokens seen in the evaluation data with the tag vocabulary
    # stored in the checkpoint.
    vocab_type = namedtuple("Vocab", ["tags", "tokens"])
    vocab = vocab_type(tags=tag_vocab, tokens=data_vocab.tokens)

    # One non-shuffled dataloader per dataset.
    no_shuffle = [False] * len(datasets)
    dataloaders = get_dataloaders(
        datasets,
        vocab,
        train_config.data_config,
        batch_size=args.batch_size,
        shuffle=no_shuffle,
    )

    # Dataloaders are consumed in the same order as the input paths.
    for dataloader, input_file in zip(dataloaders, args.data_paths):
        predictions_file = os.path.join(
            args.output_path,
            f"predictions_{os.path.basename(input_file)}",
        )

        # Only the second element of the eval result (the segments) is used.
        _, segments, _, _ = tagger.eval(dataloader)
        tagger.segments_to_file(segments, predictions_file)

        # The trainer name recorded at training time selects the metric.
        if "Nested" not in train_config.trainer_config["fn"]:
            compute_single_label_metrics(segments)
            continue

        # The nested metric receives the tag vocab without its first entry.
        compute_nested_metrics(segments, vocab.tags[1:])
84
+
85
+
86
if __name__ == "__main__":
    # Script entry point: parse the command line, then run the evaluation.
    cli_args = parse_args()
    main(cli_args)