naghamghanim committed (verified)
Commit f316449 · 1 Parent(s): 8895e63

Upload 37 files

Files changed (37)
  1. Nested/__init__.py +0 -0
  2. Nested/__pycache__/__init__.cpython-311.pyc +0 -0
  3. Nested/bin/__init__.py +0 -0
  4. Nested/bin/eval.py +87 -0
  5. Nested/bin/infer.py +73 -0
  6. Nested/bin/process.py +140 -0
  7. Nested/bin/train.py +222 -0
  8. Nested/data/__init__.py +0 -0
  9. Nested/data/__pycache__/__init__.cpython-311.pyc +0 -0
  10. Nested/data/__pycache__/datasets.cpython-311.pyc +0 -0
  11. Nested/data/__pycache__/transforms.cpython-311.pyc +0 -0
  12. Nested/data/datasets.py +150 -0
  13. Nested/data/transforms.py +127 -0
  14. Nested/nn/BaseModel.py +22 -0
  15. Nested/nn/BertNestedTagger.py +34 -0
  16. Nested/nn/BertSeqTagger.py +4 -1
  17. Nested/nn/__init__.py +3 -0
  18. Nested/nn/__pycache__/BaseModel.cpython-311.pyc +0 -0
  19. Nested/nn/__pycache__/BertNestedTagger.cpython-311.pyc +0 -0
  20. Nested/nn/__pycache__/BertSeqTagger.cpython-311.pyc +0 -0
  21. Nested/nn/__pycache__/__init__.cpython-311.pyc +0 -0
  22. Nested/trainers/BaseTrainer.py +117 -0
  23. Nested/trainers/BertNestedTrainer.py +203 -0
  24. Nested/trainers/BertTrainer.py +163 -0
  25. Nested/trainers/__init__.py +3 -0
  26. Nested/trainers/__pycache__/BaseTrainer.cpython-311.pyc +0 -0
  27. Nested/trainers/__pycache__/BertNestedTrainer.cpython-311.pyc +0 -0
  28. Nested/trainers/__pycache__/BertTrainer.cpython-311.pyc +0 -0
  29. Nested/trainers/__pycache__/__init__.cpython-311.pyc +0 -0
  30. Nested/utils/__init__.py +0 -0
  31. Nested/utils/__pycache__/__init__.cpython-311.pyc +0 -0
  32. Nested/utils/__pycache__/data.cpython-311.pyc +0 -0
  33. Nested/utils/__pycache__/helpers.cpython-311.pyc +0 -0
  34. Nested/utils/__pycache__/metrics.cpython-311.pyc +0 -0
  35. Nested/utils/data.py +112 -38
  36. Nested/utils/helpers.py +117 -0
  37. Nested/utils/metrics.py +69 -0
Nested/__init__.py ADDED
File without changes
Nested/__pycache__/__init__.cpython-311.pyc ADDED
Binary file (149 Bytes)
 
Nested/bin/__init__.py ADDED
File without changes
Nested/bin/eval.py ADDED
@@ -0,0 +1,87 @@
import os
import logging
import argparse
from collections import namedtuple
from Nested.utils.helpers import load_checkpoint, make_output_dirs, logging_config
from Nested.utils.data import get_dataloaders, parse_conll_files
from Nested.utils.metrics import compute_single_label_metrics, compute_nested_metrics

logger = logging.getLogger(__name__)


def parse_args():
    parser = argparse.ArgumentParser(
        formatter_class=argparse.ArgumentDefaultsHelpFormatter,
    )

    parser.add_argument(
        "--output_path",
        type=str,
        required=True,
        help="Path to save results",
    )

    parser.add_argument(
        "--model_path",
        type=str,
        required=True,
        help="Model path",
    )

    parser.add_argument(
        "--data_paths",
        nargs="+",
        type=str,
        required=True,
        help="Text or sequences to tag, in the same format as the training data, with the 'O' tag for all tokens",
    )

    parser.add_argument(
        "--batch_size",
        type=int,
        default=32,
        help="Batch size",
    )

    args = parser.parse_args()

    return args


def main(args):
    # Create directory to save predictions
    make_output_dirs(args.output_path, overwrite=True)
    logging_config(log_file=os.path.join(args.output_path, "eval.log"))

    # Load tagger
    tagger, tag_vocab, train_config = load_checkpoint(args.model_path)

    # Convert the input files to tagger datasets and index their tokens
    datasets, vocab = parse_conll_files(args.data_paths)

    vocabs = namedtuple("Vocab", ["tags", "tokens"])
    vocab = vocabs(tokens=vocab.tokens, tags=tag_vocab)

    # From the datasets generate the dataloaders
    dataloaders = get_dataloaders(
        datasets, vocab,
        train_config.data_config,
        batch_size=args.batch_size,
        shuffle=[False] * len(datasets)
    )

    # Evaluate the model on each dataloader
    for dataloader, input_file in zip(dataloaders, args.data_paths):
        filename = os.path.basename(input_file)
        predictions_file = os.path.join(args.output_path, f"predictions_{filename}")
        _, segments, _, _ = tagger.eval(dataloader)
        tagger.segments_to_file(segments, predictions_file)

        if "Nested" in train_config.trainer_config["fn"]:
            compute_nested_metrics(segments, vocab.tags[1:])
        else:
            compute_single_label_metrics(segments)


if __name__ == "__main__":
    main(parse_args())
Nested/bin/infer.py ADDED
@@ -0,0 +1,73 @@
import logging
import argparse
from collections import namedtuple
from Nested.utils.helpers import load_checkpoint
from Nested.utils.data import get_dataloaders, text2segments

logger = logging.getLogger(__name__)


def parse_args():
    parser = argparse.ArgumentParser(
        formatter_class=argparse.ArgumentDefaultsHelpFormatter,
    )

    parser.add_argument(
        "--model_path",
        type=str,
        required=True,
        help="Model path",
    )

    parser.add_argument(
        "--text",
        type=str,
        required=True,
        help="Text or sequence to tag",
    )

    parser.add_argument(
        "--batch_size",
        type=int,
        default=32,
        help="Batch size",
    )

    args = parser.parse_args()

    return args


def main(args):
    # Load tagger
    tagger, tag_vocab, train_config = load_checkpoint(args.model_path)

    # Convert text to a tagger dataset and index the tokens in args.text
    dataset, token_vocab = text2segments(args.text)

    vocabs = namedtuple("Vocab", ["tags", "tokens"])
    vocab = vocabs(tokens=token_vocab, tags=tag_vocab)

    # From the datasets generate the dataloaders
    dataloader = get_dataloaders(
        (dataset,),
        vocab,
        train_config.data_config,
        batch_size=args.batch_size,
        shuffle=(False,),
    )[0]

    # Perform inference on the text and get back the tagged segments
    segments = tagger.infer(dataloader)

    # Print results
    for segment in segments:
        s = [
            f"{token.text} ({'|'.join([t['tag'] for t in token.pred_tag])})"
            for token in segment
        ]
        print(" ".join(s))


if __name__ == "__main__":
    main(parse_args())
Nested/bin/process.py ADDED
@@ -0,0 +1,140 @@
import os
import argparse
import csv
import logging
import numpy as np
from Nested.utils.helpers import logging_config
from Nested.utils.data import conll_to_segments

logger = logging.getLogger(__name__)


def to_conll_format(input_files, output_path, multi_label=False):
    """
    Parse data files and convert them into CoNLL format
    :param input_files: List[str] - list of filenames
    :param output_path: str - output path
    :param multi_label: boolean - True to process multi-class/multi-label data
    :return:
    """
    for input_file in input_files:
        tokens = list()
        prev_sent_id = None

        with open(input_file, "r") as fh:
            r = csv.reader(fh, delimiter="\t", quotechar=" ")
            next(r)

            for row in r:
                sent_id, token, labels = row[1], row[3], row[4].split()
                valid_labels = sum([1 for l in labels if "-" in l or l == "O"]) == len(labels)

                if not valid_labels:
                    logging.warning("Invalid labels found %s", str(row))
                    continue
                if not labels:
                    logging.warning("Token %s has no label", str(row))
                    continue
                if not token:
                    logging.warning("Token %s is missing", str(row))
                    continue
                if len(token.split()) > 1:
                    logging.warning("Token %s has multiple tokens", str(row))
                    continue

                if prev_sent_id is not None and sent_id != prev_sent_id:
                    tokens.append([])

                if multi_label:
                    tokens.append([token] + labels)
                else:
                    tokens.append([token, labels[0]])

                prev_sent_id = sent_id

        num_segments = sum([1 for token in tokens if not token])
        logging.info("Found %d segments and %d tokens in %s", num_segments + 1, len(tokens) - num_segments, input_file)

        filename = os.path.basename(input_file)
        output_file = os.path.join(output_path, filename)

        with open(output_file, "w") as fh:
            fh.write("\n".join(" ".join(token) for token in tokens))
            logging.info("Output file %s", output_file)


def train_dev_test_split(input_files, output_path, train_ratio, dev_ratio):
    segments = list()
    filenames = ["train.txt", "val.txt", "test.txt"]

    for input_file in input_files:
        segments += conll_to_segments(input_file)

    n = len(segments)
    np.random.shuffle(segments)
    datasets = np.split(segments, [int(train_ratio*n), int((train_ratio+dev_ratio)*n)])

    # write data to files
    for i in range(len(datasets)):
        filename = os.path.join(output_path, filenames[i])

        with open(filename, "w") as fh:
            text = "\n\n".join(["\n".join([f"{token.text} {' '.join(token.gold_tag)}" for token in segment]) for segment in datasets[i]])
            fh.write(text)
            logging.info("Output file %s", filename)


def main(args):
    if args.task == "to_conll_format":
        to_conll_format(args.input_files, args.output_path, multi_label=args.multi_label)
    if args.task == "train_dev_test_split":
        train_dev_test_split(args.input_files, args.output_path, args.train_ratio, args.dev_ratio)


if __name__ == "__main__":
    parser = argparse.ArgumentParser(
        formatter_class=argparse.ArgumentDefaultsHelpFormatter,
    )

    parser.add_argument(
        "--input_files",
        type=str,
        nargs="+",
        required=True,
        help="List of input files",
    )

    parser.add_argument(
        "--output_path",
        type=str,
        required=True,
        help="Output path",
    )

    parser.add_argument(
        "--train_ratio",
        type=float,
        required=False,
        help="Training data ratio (percent of segments). Required with the task train_dev_test_split. "
             "Files must be in CoNLL format",
    )

    parser.add_argument(
        "--dev_ratio",
        type=float,
        required=False,
        help="Dev/val data ratio (percent of segments). Required with the task train_dev_test_split. "
             "Files must be in CoNLL format",
    )

    parser.add_argument(
        "--task", required=True, choices=["to_conll_format", "train_dev_test_split"]
    )

    parser.add_argument(
        "--multi_label", action='store_true'
    )

    args = parser.parse_args()
    logging_config(os.path.join(args.output_path, "process.log"))
    main(args)
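
For reference, train_dev_test_split partitions the shuffled segments with np.split at cumulative cut points. A minimal sketch of that index arithmetic, with an invented segment count of 100 and an 80/10/10 split:

import numpy as np

segments = list(range(100))          # stands in for segments read by conll_to_segments
train_ratio, dev_ratio = 0.8, 0.1    # illustrative ratios

n = len(segments)
np.random.shuffle(segments)

# np.split cuts at the cumulative boundaries [80, 90],
# producing three arrays of sizes 80, 10 and 10
train, dev, test = np.split(segments, [int(train_ratio * n), int((train_ratio + dev_ratio) * n)])
print(len(train), len(dev), len(test))  # 80 10 10
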
Nested/bin/train.py ADDED
@@ -0,0 +1,222 @@
import os
import logging
import json
import argparse
import torch.utils.tensorboard
from torchvision import *
import pickle
from Nested.utils.data import get_dataloaders, parse_conll_files
from Nested.utils.helpers import logging_config, load_object, make_output_dirs, set_seed

logger = logging.getLogger(__name__)


def parse_args():
    parser = argparse.ArgumentParser(
        formatter_class=argparse.ArgumentDefaultsHelpFormatter,
    )

    parser.add_argument(
        "--output_path",
        type=str,
        required=True,
        help="Output path",
    )

    parser.add_argument(
        "--train_path",
        type=str,
        required=True,
        help="Path to training data",
    )

    parser.add_argument(
        "--val_path",
        type=str,
        required=True,
        help="Path to validation data",
    )

    parser.add_argument(
        "--test_path",
        type=str,
        required=True,
        help="Path to test data",
    )

    parser.add_argument(
        "--bert_model",
        type=str,
        default="aubmindlab/bert-base-arabertv2",
        help="BERT model",
    )

    parser.add_argument(
        "--gpus",
        type=int,
        nargs="+",
        default=[0],
        help="GPU IDs to train on",
    )

    parser.add_argument(
        "--log_interval",
        type=int,
        default=10,
        help="Log results every this many timesteps",
    )

    parser.add_argument(
        "--batch_size",
        type=int,
        default=32,
        help="Batch size",
    )

    parser.add_argument(
        "--num_workers",
        type=int,
        default=0,
        help="Dataloader number of workers",
    )

    parser.add_argument(
        "--data_config",
        type=json.loads,
        default='{"fn": "Nested.data.datasets.DefaultDataset", "kwargs": {"max_seq_len": 512}}',
        help="Dataset configurations",
    )

    parser.add_argument(
        "--trainer_config",
        type=json.loads,
        default='{"fn": "Nested.trainers.BertTrainer", "kwargs": {"max_epochs": 50}}',
        help="Trainer configurations",
    )

    parser.add_argument(
        "--network_config",
        type=json.loads,
        default='{"fn": "Nested.nn.BertSeqTagger", "kwargs": '
                '{"dropout": 0.1, "bert_model": "aubmindlab/bert-base-arabertv2"}}',
        help="Network configurations",
    )

    parser.add_argument(
        "--optimizer",
        type=json.loads,
        default='{"fn": "torch.optim.AdamW", "kwargs": {"lr": 0.0001}}',
        help="Optimizer configurations",
    )

    parser.add_argument(
        "--lr_scheduler",
        type=json.loads,
        default='{"fn": "torch.optim.lr_scheduler.ExponentialLR", "kwargs": {"gamma": 1}}',
        help="Learning rate scheduler configurations",
    )

    parser.add_argument(
        "--loss",
        type=json.loads,
        default='{"fn": "torch.nn.CrossEntropyLoss", "kwargs": {}}',
        help="Loss function configurations",
    )

    parser.add_argument(
        "--overwrite",
        action="store_true",
        help="Overwrite output directory",
    )

    parser.add_argument(
        "--seed",
        type=int,
        default=1,
        help="Seed for random initialization",
    )

    args = parser.parse_args()

    return args


def main(args):
    make_output_dirs(
        args.output_path,
        subdirs=("tensorboard", "checkpoints"),
        overwrite=args.overwrite,
    )

    # Set the seed for randomization
    set_seed(args.seed)

    logging_config(os.path.join(args.output_path, "train.log"))
    summary_writer = torch.utils.tensorboard.SummaryWriter(
        os.path.join(args.output_path, "tensorboard")
    )
    os.environ["CUDA_VISIBLE_DEVICES"] = ",".join([str(gpu) for gpu in args.gpus])

    # Get the datasets and vocab for tags and tokens
    datasets, vocab = parse_conll_files((args.train_path, args.val_path, args.test_path))

    if "Nested" in args.network_config["fn"]:
        args.network_config["kwargs"]["num_labels"] = [len(v) for v in vocab.tags[1:]]
    else:
        args.network_config["kwargs"]["num_labels"] = len(vocab.tags[0])

    args.data_config["kwargs"]["bert_model"] = args.network_config["kwargs"]["bert_model"]

    # Save tag vocab to disk
    with open(os.path.join(args.output_path, "tag_vocab.pkl"), "wb") as fh:
        pickle.dump(vocab.tags, fh)

    # Write config to file
    args_file = os.path.join(args.output_path, "args.json")
    with open(args_file, "w") as fh:
        logger.info("Writing config to %s", args_file)
        json.dump(args.__dict__, fh, indent=4)

    # From the datasets generate the dataloaders
    train_dataloader, val_dataloader, test_dataloader = get_dataloaders(
        datasets, vocab, args.data_config, args.batch_size, args.num_workers
    )

    model = load_object(args.network_config["fn"], args.network_config["kwargs"])
    model = torch.nn.DataParallel(model, device_ids=range(len(args.gpus)))

    if torch.cuda.is_available():
        model = model.cuda()

    args.optimizer["kwargs"]["params"] = model.parameters()
    optimizer = load_object(args.optimizer["fn"], args.optimizer["kwargs"])

    args.lr_scheduler["kwargs"]["optimizer"] = optimizer
    if "num_training_steps" in args.lr_scheduler["kwargs"]:
        args.lr_scheduler["kwargs"]["num_training_steps"] = args.max_epochs * len(
            train_dataloader
        )

    scheduler = load_object(args.lr_scheduler["fn"], args.lr_scheduler["kwargs"])
    loss = load_object(args.loss["fn"], args.loss["kwargs"])

    args.trainer_config["kwargs"].update({
        "model": model,
        "optimizer": optimizer,
        "scheduler": scheduler,
        "loss": loss,
        "train_dataloader": train_dataloader,
        "val_dataloader": val_dataloader,
        "test_dataloader": test_dataloader,
        "log_interval": args.log_interval,
        "summary_writer": summary_writer,
        "output_path": args.output_path
    })

    trainer = load_object(args.trainer_config["fn"], args.trainer_config["kwargs"])
    trainer.train()
    return


if __name__ == "__main__":
    main(parse_args())
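
train.py builds every component (dataset, network, optimizer, scheduler, loss, trainer) through load_object, which resolves a dotted name and calls it with kwargs. A minimal sketch of that pattern using a standard torch class; the config dict below is illustrative and only mirrors the shape of the --optimizer default:

import importlib
import torch

def load_object(name, kwargs):
    # Same idea as Nested.utils.helpers.load_object: import the module,
    # look up the attribute, and call it with the given kwargs
    module_name, object_name = name.rsplit(".", 1)
    module = importlib.import_module(module_name)
    return getattr(module, object_name)(**kwargs)

# Illustrative optimizer config in the same shape as the --optimizer default
config = {"fn": "torch.optim.AdamW", "kwargs": {"lr": 0.0001}}
config["kwargs"]["params"] = [torch.nn.Parameter(torch.zeros(3))]

optimizer = load_object(config["fn"], config["kwargs"])
print(type(optimizer).__name__)  # AdamW
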
Nested/data/__init__.py ADDED
File without changes
Nested/data/__pycache__/__init__.cpython-311.pyc ADDED
Binary file (154 Bytes)
 
Nested/data/__pycache__/datasets.cpython-311.pyc ADDED
Binary file (7.31 kB)
 
Nested/data/__pycache__/transforms.cpython-311.pyc ADDED
Binary file (9.26 kB)
 
Nested/data/datasets.py ADDED
@@ -0,0 +1,150 @@
import logging
import torch
from torch.utils.data import Dataset
from torch.nn.utils.rnn import pad_sequence
from Nested.data.transforms import (
    BertSeqTransform,
    NestedTagsTransform
)

logger = logging.getLogger(__name__)


class Token:
    def __init__(self, text=None, pred_tag=None, gold_tag=None):
        """
        Token object to hold token attributes
        :param text: str
        :param pred_tag: str
        :param gold_tag: str
        """
        self.text = text
        self.gold_tag = gold_tag
        self.pred_tag = pred_tag
        self.subwords = None

    @property
    def subwords(self):
        return self._subwords

    @subwords.setter
    def subwords(self, value):
        self._subwords = value

    def __str__(self):
        """
        Token text representation
        :return: str
        """
        gold_tags = "|".join(self.gold_tag)

        if self.pred_tag:
            pred_tags = "|".join([pred_tag["tag"] for pred_tag in self.pred_tag])
        else:
            pred_tags = ""

        if self.gold_tag:
            r = f"{self.text}\t{gold_tags}\t{pred_tags}"
        else:
            r = f"{self.text}\t{pred_tags}"

        return r


class DefaultDataset(Dataset):
    def __init__(
        self,
        examples=None,
        vocab=None,
        bert_model="aubmindlab/bert-base-arabertv2",
        max_seq_len=512,
    ):
        """
        The dataset used to transform the segments into training data
        :param examples: list[[tuple]] - [[(token, tag), (token, tag), ...], [(token, tag), ...]]
                         You can generate examples from Nested.utils.data.parse_conll_files
        :param vocab: vocab object containing indexed tags and tokens
        :param bert_model: str - BERT model
        :param max_seq_len: int - maximum sequence length
        """
        self.transform = BertSeqTransform(bert_model, vocab, max_seq_len=max_seq_len)
        self.examples = examples
        self.vocab = vocab

    def __len__(self):
        return len(self.examples)

    def __getitem__(self, item):
        subwords, tags, tokens, valid_len = self.transform(self.examples[item])
        return subwords, tags, tokens, valid_len

    def collate_fn(self, batch):
        """
        Collate function that is called when the batch is fetched by the trainer
        :param batch: Dataloader batch
        :return: Same output as the __getitem__ function
        """
        subwords, tags, tokens, valid_len = zip(*batch)

        # Pad sequences in this batch
        # subwords and tokens are padded with zeros
        # tags are padded with the index of the O tag
        subwords = pad_sequence(subwords, batch_first=True, padding_value=0)
        tags = pad_sequence(
            tags, batch_first=True, padding_value=self.vocab.tags[0].get_stoi()["O"]
        )
        return subwords, tags, tokens, valid_len


class NestedTagsDataset(Dataset):
    def __init__(
        self,
        examples=None,
        vocab=None,
        bert_model="aubmindlab/bert-base-arabertv2",
        max_seq_len=512,
    ):
        """
        The dataset used to transform the segments into training data
        :param examples: list[[tuple]] - [[(token, tag), (token, tag), ...], [(token, tag), ...]]
                         You can generate examples from Nested.utils.data.parse_conll_files
        :param vocab: vocab object containing indexed tags and tokens
        :param bert_model: str - BERT model
        :param max_seq_len: int - maximum sequence length
        """
        self.transform = NestedTagsTransform(
            bert_model, vocab, max_seq_len=max_seq_len
        )
        self.examples = examples
        self.vocab = vocab

    def __len__(self):
        return len(self.examples)

    def __getitem__(self, item):
        subwords, tags, tokens, masks, valid_len = self.transform(self.examples[item])
        return subwords, tags, tokens, masks, valid_len

    def collate_fn(self, batch):
        """
        Collate function that is called when the batch is fetched by the trainer
        :param batch: Dataloader batch
        :return: Same output as the __getitem__ function
        """
        subwords, tags, tokens, masks, valid_len = zip(*batch)

        # Pad sequences in this batch
        # subwords and tokens are padded with zeros
        # tags are padded with the index of the O tag
        subwords = pad_sequence(subwords, batch_first=True, padding_value=0)

        masks = [torch.nn.ConstantPad1d((0, subwords.shape[-1] - tag.shape[-1]), 0)(mask)
                 for tag, mask in zip(tags, masks)]
        masks = torch.cat(masks)

        # Pad the tags, do the padding for each tag type
        tags = [torch.nn.ConstantPad1d((0, subwords.shape[-1] - tag.shape[-1]), vocab.get_stoi()["O"])(tag)
                for tag, vocab in zip(tags, self.vocab.tags[1:])]
        tags = torch.cat(tags)

        return subwords, tags, tokens, masks, valid_len
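
A minimal sketch of what DefaultDataset.collate_fn does to a batch, using dummy tensors in place of real BERT subword IDs; the o_index value stands in for self.vocab.tags[0].get_stoi()["O"]:

import torch
from torch.nn.utils.rnn import pad_sequence

# Two variable-length sequences of subword IDs and tag IDs (dummy values)
subwords = [torch.LongTensor([2, 15, 27, 3]), torch.LongTensor([2, 40, 3])]
tags = [torch.LongTensor([0, 5, 0, 0]), torch.LongTensor([0, 7, 0])]
o_index = 0  # stands in for vocab.tags[0].get_stoi()["O"]

# Subwords are padded with zeros, tags with the "O" index,
# so both become B x T_max tensors as in DefaultDataset.collate_fn
subwords = pad_sequence(subwords, batch_first=True, padding_value=0)
tags = pad_sequence(tags, batch_first=True, padding_value=o_index)
print(subwords.shape, tags.shape)  # torch.Size([2, 4]) torch.Size([2, 4])
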
Nested/data/transforms.py ADDED
@@ -0,0 +1,127 @@
import torch
from transformers import BertTokenizer
from functools import partial
import logging
import re
import itertools
import Nested

logger = logging.getLogger(__name__)


class BertSeqTransform:
    def __init__(self, bert_model, vocab, max_seq_len=512):
        self.tokenizer = BertTokenizer.from_pretrained(bert_model)
        self.encoder = partial(
            self.tokenizer.encode,
            max_length=max_seq_len,
            truncation=True,
        )
        self.max_seq_len = max_seq_len
        self.vocab = vocab

    def __call__(self, segment):
        subwords, tags, tokens = list(), list(), list()
        unk_token = Nested.data.datasets.Token(text="UNK")

        for token in segment:
            # Sometimes the tokenizer fails to encode the word and returns no input_ids,
            # in that case we use the input_id for [UNK]
            token_subwords = self.encoder(token.text)[1:-1] or self.encoder("[UNK]")[1:-1]
            subwords += token_subwords
            tags += [self.vocab.tags[0].get_stoi()[token.gold_tag[0]]] + [self.vocab.tags[0].get_stoi()["O"]] * (len(token_subwords) - 1)
            tokens += [token] + [unk_token] * (len(token_subwords) - 1)

        # Truncate to max_seq_len
        if len(subwords) > self.max_seq_len - 2:
            text = " ".join([t.text for t in tokens if t.text != "UNK"])
            logger.info("Truncating the sequence %s to %d", text, self.max_seq_len - 2)
            subwords = subwords[:self.max_seq_len - 2]
            tags = tags[:self.max_seq_len - 2]
            tokens = tokens[:self.max_seq_len - 2]

        subwords.insert(0, self.tokenizer.cls_token_id)
        subwords.append(self.tokenizer.sep_token_id)

        tags.insert(0, self.vocab.tags[0].get_stoi()["O"])
        tags.append(self.vocab.tags[0].get_stoi()["O"])

        tokens.insert(0, unk_token)
        tokens.append(unk_token)

        return torch.LongTensor(subwords), torch.LongTensor(tags), tokens, len(tokens)


class NestedTagsTransform:
    def __init__(self, bert_model, vocab, max_seq_len=512):
        self.tokenizer = BertTokenizer.from_pretrained(bert_model)
        self.encoder = partial(
            self.tokenizer.encode,
            max_length=max_seq_len,
            truncation=True,
        )
        self.max_seq_len = max_seq_len
        self.vocab = vocab

    def __call__(self, segment):
        tags, tokens, subwords = list(), list(), list()
        unk_token = Nested.data.datasets.Token(text="UNK")

        # Encode each token and get its subwords and IDs
        for token in segment:
            # Sometimes the tokenizer fails to encode the word and returns no input_ids,
            # in that case we use the input_id for [UNK]
            token.subwords = self.encoder(token.text)[1:-1] or self.encoder("[UNK]")[1:-1]
            subwords += token.subwords
            tokens += [token] + [unk_token] * (len(token.subwords) - 1)

        # Construct the labels for each tag type
        # The sequence will have a list of tags for each type
        # The final tags for a sequence is a matrix NUM_TAG_TYPES x SEQ_LEN
        # Example:
        # [
        #     [O, O, B-PERS, I-PERS, O, O, O]
        #     [B-ORG, I-ORG, O, O, O, O, O]
        #     [O, O, O, O, O, O, B-GPE]
        # ]
        for vocab in self.vocab.tags[1:]:
            vocab_tags = "|".join(["^" + t + "$" for t in vocab.get_itos() if "-" in t])
            r = re.compile(vocab_tags)

            # This is really messy
            # For a given token we find a matching tag_name, BUT we might find
            # multiple matches (i.e. a token can be labeled B-ORG and I-ORG), in this
            # case we take only the first tag as we do not have overlaps of the same type
            single_type_tags = [[(list(filter(r.match, token.gold_tag))
                                  or ["O"])[0]] + ["O"] * (len(token.subwords) - 1)
                                for token in segment]
            single_type_tags = list(itertools.chain(*single_type_tags))
            tags.append([vocab.get_stoi()[tag] for tag in single_type_tags])

        # Truncate to max_seq_len
        if len(subwords) > self.max_seq_len - 2:
            text = " ".join([t.text for t in tokens if t.text != "UNK"])
            logger.info("Truncating the sequence %s to %d", text, self.max_seq_len - 2)
            subwords = subwords[:self.max_seq_len - 2]
            tags = [t[:self.max_seq_len - 2] for t in tags]
            tokens = tokens[:self.max_seq_len - 2]

        # Add a dummy token at the start and end of the sequence
        tokens.insert(0, unk_token)
        tokens.append(unk_token)

        # Add [CLS] and [SEP] at the start and end of subwords
        subwords.insert(0, self.tokenizer.cls_token_id)
        subwords.append(self.tokenizer.sep_token_id)
        subwords = torch.LongTensor(subwords)

        # Add "O" tags for the first and last subwords
        tags = torch.Tensor(tags)
        tags = torch.column_stack((
            torch.Tensor([vocab.get_stoi()["O"] for vocab in self.vocab.tags[1:]]),
            tags,
            torch.Tensor([vocab.get_stoi()["O"] for vocab in self.vocab.tags[1:]]),
        )).unsqueeze(0)

        mask = torch.ones_like(tags)
        return subwords, tags, tokens, mask, len(tokens)
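
To make the per-type label matrix concrete, here is a small self-contained sketch of the same idea as in NestedTagsTransform: each tag type gets its own row, and a token keeps only the first gold tag matching that type. The sentence, tag types, and gold tags below are invented for illustration:

import re

# Invented example: one token may carry several nested gold tags
gold_tags_per_token = [["B-ORG"], ["I-ORG"], ["B-PERS", "I-ORG"], ["I-PERS"], ["O"]]
tag_types = {"PERS": ["B-PERS", "I-PERS"], "ORG": ["B-ORG", "I-ORG"]}

rows = {}
for tag_type, valid_tags in tag_types.items():
    r = re.compile("|".join("^" + t + "$" for t in valid_tags))
    # Keep the first gold tag matching this type, otherwise "O"
    rows[tag_type] = [(list(filter(r.match, tags)) or ["O"])[0] for tags in gold_tags_per_token]

print(rows["PERS"])  # ['O', 'O', 'B-PERS', 'I-PERS', 'O']
print(rows["ORG"])   # ['B-ORG', 'I-ORG', 'I-ORG', 'O', 'O']
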
Nested/nn/BaseModel.py ADDED
@@ -0,0 +1,22 @@
from torch import nn
from transformers import BertModel
import logging

logger = logging.getLogger(__name__)


class BaseModel(nn.Module):
    def __init__(self,
                 bert_model="aubmindlab/bert-base-arabertv2",
                 num_labels=2,
                 dropout=0.1,
                 num_types=0):
        super().__init__()

        self.bert_model = bert_model
        self.num_labels = num_labels
        self.num_types = num_types
        self.dropout = dropout

        self.bert = BertModel.from_pretrained(bert_model)
        self.dropout = nn.Dropout(dropout)
Nested/nn/BertNestedTagger.py ADDED
@@ -0,0 +1,34 @@
import torch
import torch.nn as nn
from Nested.nn import BaseModel


class BertNestedTagger(BaseModel):
    def __init__(self, **kwargs):
        super(BertNestedTagger, self).__init__(**kwargs)

        self.max_num_labels = max(self.num_labels)
        classifiers = [nn.Linear(768, num_labels) for num_labels in self.num_labels]
        self.classifiers = torch.nn.Sequential(*classifiers)

    def forward(self, x):
        y = self.bert(x)
        y = self.dropout(y["last_hidden_state"])
        output = list()

        for i, classifier in enumerate(self.classifiers):
            logits = classifier(y)

            # Pad logits to allow Multi-GPU/DataParallel training to work
            # We will truncate the padded dimensions when we compute the loss in the trainer
            logits = torch.nn.ConstantPad1d((0, self.max_num_labels - logits.shape[-1]), 0)(logits)
            output.append(logits)

        # Return tensor of the shape B x T x L x C
        # B: batch size
        # T: sequence length
        # L: number of tag types
        # C: number of classes per tag type
        output = torch.stack(output).permute((1, 2, 0, 3))
        return output
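
A small sketch of the padding-and-stacking step at the end of forward, using a random tensor in place of BERT's last_hidden_state so it runs without downloading a model; the per-type label counts [5, 3] are illustrative:

import torch
import torch.nn as nn

batch, seq_len, hidden = 2, 7, 768
num_labels = [5, 3]                # illustrative per-type class counts
max_num_labels = max(num_labels)

y = torch.randn(batch, seq_len, hidden)  # stands in for BERT's last_hidden_state
classifiers = nn.ModuleList(nn.Linear(hidden, n) for n in num_labels)

output = []
for n, classifier in zip(num_labels, classifiers):
    logits = classifier(y)  # B x T x n
    # Pad the class dimension so every tag type has the same width
    logits = nn.ConstantPad1d((0, max_num_labels - n), 0)(logits)
    output.append(logits)

# Stack to L x B x T x C, then permute to B x T x L x C as in BertNestedTagger
output = torch.stack(output).permute((1, 2, 0, 3))
print(output.shape)  # torch.Size([2, 7, 2, 5])
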
Nested/nn/BertSeqTagger.py CHANGED
@@ -1,14 +1,17 @@
 import torch.nn as nn
 from transformers import BertModel
 
+
 class BertSeqTagger(nn.Module):
     def __init__(self, bert_model, num_labels=2, dropout=0.1):
         super().__init__()
+
         self.bert = BertModel.from_pretrained(bert_model)
         self.dropout = nn.Dropout(dropout)
         self.linear = nn.Linear(768, num_labels)
+
     def forward(self, x):
         y = self.bert(x)
         y = self.dropout(y["last_hidden_state"])
         logits = self.linear(y)
-        return logits
+        return logits
Nested/nn/__init__.py ADDED
@@ -0,0 +1,3 @@
from Nested.nn.BaseModel import BaseModel
from Nested.nn.BertSeqTagger import BertSeqTagger
from Nested.nn.BertNestedTagger import BertNestedTagger
Nested/nn/__pycache__/BaseModel.cpython-311.pyc ADDED
Binary file (1.34 kB)
 
Nested/nn/__pycache__/BertNestedTagger.cpython-311.pyc ADDED
Binary file (2.33 kB)
 
Nested/nn/__pycache__/BertSeqTagger.cpython-311.pyc ADDED
Binary file (1.54 kB)
 
Nested/nn/__pycache__/__init__.cpython-311.pyc ADDED
Binary file (379 Bytes)
 
Nested/trainers/BaseTrainer.py ADDED
@@ -0,0 +1,117 @@
import os
import torch
import logging
import natsort
import glob

logger = logging.getLogger(__name__)


class BaseTrainer:
    def __init__(
        self,
        model=None,
        max_epochs=50,
        optimizer=None,
        scheduler=None,
        loss=None,
        train_dataloader=None,
        val_dataloader=None,
        test_dataloader=None,
        log_interval=10,
        summary_writer=None,
        output_path=None,
        clip=5,
        patience=5
    ):
        self.model = model
        self.max_epochs = max_epochs
        self.train_dataloader = train_dataloader
        self.val_dataloader = val_dataloader
        self.test_dataloader = test_dataloader
        self.optimizer = optimizer
        self.scheduler = scheduler
        self.loss = loss
        self.log_interval = log_interval
        self.summary_writer = summary_writer
        self.output_path = output_path
        self.current_timestep = 0
        self.current_epoch = 0
        self.clip = clip
        self.patience = patience

    def tag(self, dataloader, is_train=True):
        """
        Given a dataloader containing segments, predict the tags
        :param dataloader: torch.utils.data.DataLoader
        :param is_train: boolean - True for training mode, False for evaluation
        :return: Iterator
                 subwords (B x T x NUM_LABELS) - torch.Tensor - BERT subword IDs
                 gold_tags (B x T x NUM_LABELS) - torch.Tensor - ground truth tag IDs
                 tokens - List[Nested.data.dataset.Token] - list of tokens
                 valid_len (B x 1) - int - valid length of each sequence
                 logits (B x T x NUM_LABELS) - logits for each token and each tag
        """
        for subwords, gold_tags, tokens, valid_len in dataloader:
            self.model.train(is_train)

            if torch.cuda.is_available():
                subwords = subwords.cuda()
                gold_tags = gold_tags.cuda()

            if is_train:
                self.optimizer.zero_grad()
                logits = self.model(subwords)
            else:
                with torch.no_grad():
                    logits = self.model(subwords)

            yield subwords, gold_tags, tokens, valid_len, logits

    def segments_to_file(self, segments, filename):
        """
        Write segments to file
        :param segments: [List[Nested.data.dataset.Token]] - list of list of tokens
        :param filename: str - output filename
        :return: None
        """
        with open(filename, "w") as fh:
            results = "\n\n".join(["\n".join([t.__str__() for t in segment]) for segment in segments])
            fh.write("Token\tGold Tag\tPredicted Tag\n")
            fh.write(results)
            logging.info("Predictions written to %s", filename)

    def save(self):
        """
        Save model checkpoint
        :return:
        """
        filename = os.path.join(
            self.output_path,
            "checkpoints",
            "checkpoint_{}.pt".format(self.current_epoch),
        )

        checkpoint = {
            "model": self.model.state_dict(),
            "optimizer": self.optimizer.state_dict(),
            "epoch": self.current_epoch
        }

        logger.info("Saving checkpoint to %s", filename)
        torch.save(checkpoint, filename)

    def load(self, checkpoint_path):
        """
        Load model checkpoint
        :param checkpoint_path: str - path/to/checkpoints
        :return: None
        """
        checkpoint_path = natsort.natsorted(glob.glob(f"{checkpoint_path}/checkpoint_*.pt"))
        checkpoint_path = checkpoint_path[-1]

        logger.info("Loading checkpoint %s", checkpoint_path)

        device = None if torch.cuda.is_available() else torch.device('cpu')
        checkpoint = torch.load(checkpoint_path, map_location=device, weights_only=False)
        self.model.load_state_dict(checkpoint["model"])
Nested/trainers/BertNestedTrainer.py ADDED
@@ -0,0 +1,203 @@
import os
import logging
import torch
import numpy as np
from Nested.trainers import BaseTrainer
from Nested.utils.metrics import compute_nested_metrics

logger = logging.getLogger(__name__)


class BertNestedTrainer(BaseTrainer):
    def __init__(self, **kwargs):
        super().__init__(**kwargs)

    def train(self):
        best_val_loss, test_loss = np.inf, np.inf
        num_train_batch = len(self.train_dataloader)
        num_labels = [len(v) for v in self.train_dataloader.dataset.vocab.tags[1:]]
        patience = self.patience

        for epoch_index in range(self.max_epochs):
            self.current_epoch = epoch_index
            train_loss = 0

            for batch_index, (subwords, gold_tags, tokens, valid_len, logits) in enumerate(self.tag(
                self.train_dataloader, is_train=True
            ), 1):
                self.current_timestep += 1

                # Compute the loss for each output
                # logits = B x T x L x C
                losses = [self.loss(logits[:, :, i, 0:l].view(-1, logits[:, :, i, 0:l].shape[-1]),
                                    torch.reshape(gold_tags[:, i, :], (-1,)).long())
                          for i, l in enumerate(num_labels)]

                torch.autograd.backward(losses)

                # Avoid exploding gradients by doing gradient clipping
                torch.nn.utils.clip_grad_norm_(self.model.parameters(), self.clip)

                self.optimizer.step()
                self.scheduler.step()
                batch_loss = sum(l.item() for l in losses)
                train_loss += batch_loss

                if self.current_timestep % self.log_interval == 0:
                    logger.info(
                        "Epoch %d | Batch %d/%d | Timestep %d | LR %.10f | Loss %f",
                        epoch_index,
                        batch_index,
                        num_train_batch,
                        self.current_timestep,
                        self.optimizer.param_groups[0]['lr'],
                        batch_loss
                    )

            train_loss /= num_train_batch

            logger.info("** Evaluating on validation dataset **")
            val_preds, segments, valid_len, val_loss = self.eval(self.val_dataloader)
            val_metrics = compute_nested_metrics(segments, self.val_dataloader.dataset.transform.vocab.tags[1:])

            epoch_summary_loss = {
                "train_loss": train_loss,
                "val_loss": val_loss
            }
            epoch_summary_metrics = {
                "val_micro_f1": val_metrics.micro_f1,
                "val_precision": val_metrics.precision,
                "val_recall": val_metrics.recall
            }

            logger.info(
                "Epoch %d | Timestep %d | Train Loss %f | Val Loss %f | F1 %f",
                epoch_index,
                self.current_timestep,
                train_loss,
                val_loss,
                val_metrics.micro_f1
            )

            if val_loss < best_val_loss:
                patience = self.patience
                best_val_loss = val_loss
                logger.info("** Validation improved, evaluating test data **")
                test_preds, segments, valid_len, test_loss = self.eval(self.test_dataloader)
                self.segments_to_file(segments, os.path.join(self.output_path, "predictions.txt"))
                test_metrics = compute_nested_metrics(segments, self.test_dataloader.dataset.transform.vocab.tags[1:])

                epoch_summary_loss["test_loss"] = test_loss
                epoch_summary_metrics["test_micro_f1"] = test_metrics.micro_f1
                epoch_summary_metrics["test_precision"] = test_metrics.precision
                epoch_summary_metrics["test_recall"] = test_metrics.recall

                logger.info(
                    "Epoch %d | Timestep %d | Test Loss %f | F1 %f",
                    epoch_index,
                    self.current_timestep,
                    test_loss,
                    test_metrics.micro_f1
                )

                self.save()
            else:
                patience -= 1

            # No improvements, terminating early
            if patience == 0:
                logger.info("Early termination triggered")
                break

            self.summary_writer.add_scalars("Loss", epoch_summary_loss, global_step=self.current_timestep)
            self.summary_writer.add_scalars("Metrics", epoch_summary_metrics, global_step=self.current_timestep)

    def tag(self, dataloader, is_train=True):
        """
        Given a dataloader containing segments, predict the tags
        :param dataloader: torch.utils.data.DataLoader
        :param is_train: boolean - True for training mode, False for evaluation
        :return: Iterator
                 subwords (B x T x NUM_LABELS) - torch.Tensor - BERT subword IDs
                 gold_tags (B x T x NUM_LABELS) - torch.Tensor - ground truth tag IDs
                 tokens - List[Nested.data.dataset.Token] - list of tokens
                 valid_len (B x 1) - int - valid length of each sequence
                 logits (B x T x NUM_LABELS) - logits for each token and each tag
        """
        for subwords, gold_tags, tokens, mask, valid_len in dataloader:
            self.model.train(is_train)

            if torch.cuda.is_available():
                subwords = subwords.cuda()
                gold_tags = gold_tags.cuda()

            if is_train:
                self.optimizer.zero_grad()
                logits = self.model(subwords)
            else:
                with torch.no_grad():
                    logits = self.model(subwords)

            yield subwords, gold_tags, tokens, valid_len, logits

    def eval(self, dataloader):
        golds, preds, segments, valid_lens = list(), list(), list(), list()
        num_labels = [len(v) for v in dataloader.dataset.vocab.tags[1:]]
        loss = 0

        for _, gold_tags, tokens, valid_len, logits in self.tag(
            dataloader, is_train=False
        ):
            losses = [self.loss(logits[:, :, i, 0:l].view(-1, logits[:, :, i, 0:l].shape[-1]),
                                torch.reshape(gold_tags[:, i, :], (-1,)).long())
                      for i, l in enumerate(num_labels)]
            loss += sum(losses)
            preds += torch.argmax(logits, dim=3)
            segments += tokens
            valid_lens += list(valid_len)

        loss /= len(dataloader)

        # Update segments, attach predicted tags to each token
        segments = self.to_segments(segments, preds, valid_lens, dataloader.dataset.vocab)

        return preds, segments, valid_lens, loss

    def infer(self, dataloader):
        golds, preds, segments, valid_lens = list(), list(), list(), list()

        for _, gold_tags, tokens, valid_len, logits in self.tag(
            dataloader, is_train=False
        ):
            preds += torch.argmax(logits, dim=3)
            segments += tokens
            valid_lens += list(valid_len)

        segments = self.to_segments(segments, preds, valid_lens, dataloader.dataset.vocab)
        return segments

    def to_segments(self, segments, preds, valid_lens, vocab):
        if vocab is None:
            vocab = self.vocab

        tagged_segments = list()
        tokens_stoi = vocab.tokens.get_stoi()
        unk_id = tokens_stoi["UNK"]

        for segment, pred, valid_len in zip(segments, preds, valid_lens):
            # Skip the [CLS] token at index 0 and the [SEP] token at the last index
            # Combine the tokens with their corresponding predictions
            segment_pred = zip(segment[1:valid_len-1], pred[1:valid_len-1])

            # Ignore the sub-tokens/subwords, which are identified by their text being UNK
            segment_pred = list(filter(lambda t: tokens_stoi[t[0].text] != unk_id, segment_pred))

            # Attach the predicted tags to each token
            list(map(lambda t: setattr(t[0], 'pred_tag', [{"tag": vocab.get_itos()[tag_id]}
                                                          for tag_id, vocab in zip(t[1].int().tolist(), vocab.tags[1:])]), segment_pred))

            # We are only interested in the tagged tokens; we no longer need the raw model predictions
            tagged_segment = [t for t, _ in segment_pred]
            tagged_segments.append(tagged_segment)

        return tagged_segments
Nested/trainers/BertTrainer.py ADDED
@@ -0,0 +1,163 @@
import os
import logging
import torch
import numpy as np
from Nested.trainers import BaseTrainer
from Nested.utils.metrics import compute_single_label_metrics

logger = logging.getLogger(__name__)


class BertTrainer(BaseTrainer):
    def __init__(self, **kwargs):
        super().__init__(**kwargs)

    def train(self):
        best_val_loss, test_loss = np.inf, np.inf
        num_train_batch = len(self.train_dataloader)
        patience = self.patience

        for epoch_index in range(self.max_epochs):
            self.current_epoch = epoch_index
            train_loss = 0

            for batch_index, (_, gold_tags, _, _, logits) in enumerate(self.tag(
                self.train_dataloader, is_train=True
            ), 1):
                self.current_timestep += 1
                batch_loss = self.loss(logits.view(-1, logits.shape[-1]), gold_tags.view(-1))
                batch_loss.backward()

                # Avoid exploding gradients by doing gradient clipping
                torch.nn.utils.clip_grad_norm_(self.model.parameters(), self.clip)

                self.optimizer.step()
                self.scheduler.step()
                train_loss += batch_loss.item()

                if self.current_timestep % self.log_interval == 0:
                    logger.info(
                        "Epoch %d | Batch %d/%d | Timestep %d | LR %.10f | Loss %f",
                        epoch_index,
                        batch_index,
                        num_train_batch,
                        self.current_timestep,
                        self.optimizer.param_groups[0]['lr'],
                        batch_loss.item()
                    )

            train_loss /= num_train_batch

            logger.info("** Evaluating on validation dataset **")
            val_preds, segments, valid_len, val_loss = self.eval(self.val_dataloader)
            val_metrics = compute_single_label_metrics(segments)

            epoch_summary_loss = {
                "train_loss": train_loss,
                "val_loss": val_loss
            }
            epoch_summary_metrics = {
                "val_micro_f1": val_metrics.micro_f1,
                "val_precision": val_metrics.precision,
                "val_recall": val_metrics.recall
            }

            logger.info(
                "Epoch %d | Timestep %d | Train Loss %f | Val Loss %f | F1 %f",
                epoch_index,
                self.current_timestep,
                train_loss,
                val_loss,
                val_metrics.micro_f1
            )

            if val_loss < best_val_loss:
                patience = self.patience
                best_val_loss = val_loss
                logger.info("** Validation improved, evaluating test data **")
                test_preds, segments, valid_len, test_loss = self.eval(self.test_dataloader)
                self.segments_to_file(segments, os.path.join(self.output_path, "predictions.txt"))
                test_metrics = compute_single_label_metrics(segments)

                epoch_summary_loss["test_loss"] = test_loss
                epoch_summary_metrics["test_micro_f1"] = test_metrics.micro_f1
                epoch_summary_metrics["test_precision"] = test_metrics.precision
                epoch_summary_metrics["test_recall"] = test_metrics.recall

                logger.info(
                    "Epoch %d | Timestep %d | Test Loss %f | F1 %f",
                    epoch_index,
                    self.current_timestep,
                    test_loss,
                    test_metrics.micro_f1
                )

                self.save()
            else:
                patience -= 1

            # No improvements, terminating early
            if patience == 0:
                logger.info("Early termination triggered")
                break

            self.summary_writer.add_scalars("Loss", epoch_summary_loss, global_step=self.current_timestep)
            self.summary_writer.add_scalars("Metrics", epoch_summary_metrics, global_step=self.current_timestep)

    def eval(self, dataloader):
        golds, preds, segments, valid_lens = list(), list(), list(), list()
        loss = 0

        for _, gold_tags, tokens, valid_len, logits in self.tag(
            dataloader, is_train=False
        ):
            loss += self.loss(logits.view(-1, logits.shape[-1]), gold_tags.view(-1))
            preds += torch.argmax(logits, dim=2).detach().cpu().numpy().tolist()
            segments += tokens
            valid_lens += list(valid_len)

        loss /= len(dataloader)

        # Update segments, attach predicted tags to each token
        segments = self.to_segments(segments, preds, valid_lens, dataloader.dataset.vocab)

        return preds, segments, valid_lens, loss.item()

    def infer(self, dataloader):
        golds, preds, segments, valid_lens = list(), list(), list(), list()

        for _, gold_tags, tokens, valid_len, logits in self.tag(
            dataloader, is_train=False
        ):
            preds += torch.argmax(logits, dim=2).detach().cpu().numpy().tolist()
            segments += tokens
            valid_lens += list(valid_len)

        segments = self.to_segments(segments, preds, valid_lens, dataloader.dataset.vocab)
        return segments

    def to_segments(self, segments, preds, valid_lens, vocab):
        if vocab is None:
            vocab = self.vocab

        tagged_segments = list()
        tokens_stoi = vocab.tokens.get_stoi()
        tags_itos = vocab.tags[0].get_itos()
        unk_id = tokens_stoi["UNK"]

        for segment, pred, valid_len in zip(segments, preds, valid_lens):
            # Skip the [CLS] token at index 0 and the [SEP] token at the last index
            # Combine the tokens with their corresponding predictions
            segment_pred = zip(segment[1:valid_len-1], pred[1:valid_len-1])

            # Ignore the sub-tokens/subwords, which are identified by their text being UNK
            segment_pred = list(filter(lambda t: tokens_stoi[t[0].text] != unk_id, segment_pred))

            # Attach the predicted tags to each token
            list(map(lambda t: setattr(t[0], 'pred_tag', [{"tag": tags_itos[t[1]]}]), segment_pred))

            # We are only interested in the tagged tokens; we no longer need the raw model predictions
            tagged_segment = [t for t, _ in segment_pred]
            tagged_segments.append(tagged_segment)

        return tagged_segments
Nested/trainers/__init__.py ADDED
@@ -0,0 +1,3 @@
from Nested.trainers.BaseTrainer import BaseTrainer
from Nested.trainers.BertTrainer import BertTrainer
from Nested.trainers.BertNestedTrainer import BertNestedTrainer
Nested/trainers/__pycache__/BaseTrainer.cpython-311.pyc ADDED
Binary file (6.45 kB)
 
Nested/trainers/__pycache__/BertNestedTrainer.cpython-311.pyc ADDED
Binary file (13.4 kB)
 
Nested/trainers/__pycache__/BertTrainer.cpython-311.pyc ADDED
Binary file (9.43 kB)
 
Nested/trainers/__pycache__/__init__.cpython-311.pyc ADDED
Binary file (405 Bytes)
 
Nested/utils/__init__.py ADDED
File without changes
Nested/utils/__pycache__/__init__.cpython-311.pyc ADDED
Binary file (155 Bytes)
 
Nested/utils/__pycache__/data.cpython-311.pyc ADDED
Binary file (8.66 kB)
 
Nested/utils/__pycache__/helpers.cpython-311.pyc ADDED
Binary file (5.9 kB)
 
Nested/utils/__pycache__/metrics.cpython-311.pyc ADDED
Binary file (5.45 kB)
 
Nested/utils/data.py CHANGED
@@ -1,7 +1,16 @@
1
- from collections import Counter
 
 
 
 
 
 
 
 
 
2
 
3
  class Vocab:
4
- def _init_(self, counter, specials=[]) -> None:
5
  self.itos = list(counter.keys()) + specials
6
  self.stoi = {s: i for i, s in enumerate(self.itos)}
7
  self.word_count = counter
@@ -12,44 +21,77 @@ class Vocab:
12
  def get_stoi(self) -> dict[str, int]:
13
  return self.stoi
14
 
15
- def _len_(self):
16
  return len(self.itos)
17
 
18
 
19
- class Token:
20
- def __init__(self, text=None, pred_tag=None, gold_tag=None):
21
- """
22
- Token object to hold token attributes
23
- :param text: str
24
- :param pred_tag: str
25
- :param gold_tag: str
26
- """
27
- self.text = text
28
- self.gold_tag = gold_tag
29
- self.pred_tag = pred_tag
30
- self.subwords = None
31
- @property
32
- def subwords(self):
33
- return self._subwords
34
- @subwords.setter
35
- def subwords(self, value):
36
- self._subwords = value
37
- def __str__(self):
38
- """
39
- Token text representation
40
- :return: str
41
- """
42
- gold_tags = "|".join(self.gold_tag)
43
- if self.pred_tag:
44
- pred_tags = "|".join([pred_tag["tag"] for pred_tag in self.pred_tag])
45
- else:
46
- pred_tags = ""
47
- if self.gold_tag:
48
- r = f"{self.text}\t{gold_tags}\t{pred_tags}"
49
- else:
50
- r = f"{self.text}\t{pred_tags}"
51
- return r
52
-
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
53
 
54
  def text2segments(text):
55
  """
@@ -57,6 +99,38 @@ def text2segments(text):
57
  """
58
  dataset = [[Token(text=token, gold_tag=["O"]) for token in text.split()]]
59
  tokens = [token.text for segment in dataset for token in segment]
 
60
  # Generate vocabs for the tokens
61
  segment_vocab = Vocab(Counter(tokens), specials=["UNK"])
62
- return dataset, segment_vocab
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from torch.utils.data import DataLoader
2
+ from collections import Counter, namedtuple
3
+ import logging
4
+ import re
5
+ import itertools
6
+ from Nested.utils.helpers import load_object
7
+ from Nested.data.datasets import Token
8
+
9
+ logger = logging.getLogger(__name__)
10
+
11
 
12
  class Vocab:
13
+ def __init__(self, counter, specials=[]) -> None:
14
  self.itos = list(counter.keys()) + specials
15
  self.stoi = {s: i for i, s in enumerate(self.itos)}
16
  self.word_count = counter
 
21
  def get_stoi(self) -> dict[str, int]:
22
  return self.stoi
23
 
24
+ def __len__(self):
25
  return len(self.itos)
26
 
27
 
28
+ def conll_to_segments(filename):
29
+ """
30
+ Convert CoNLL files to segments. This return list of segments and each segment is
31
+ a list of tuples (token, tag)
32
+ :param filename: Path
33
+ :return: list[[tuple]] - [[(token, tag), (token, tag), ...], [(token, tag), ...]]
34
+ """
35
+ segments, segment = list(), list()
36
+
37
+ with open(filename, "r") as fh:
38
+ for token in fh.read().splitlines():
39
+ if not token.strip():
40
+ segments.append(segment)
41
+ segment = list()
42
+ else:
43
+ parts = token.split()
44
+ token = Token(text=parts[0], gold_tag=parts[1:])
45
+ segment.append(token)
46
+
47
+ segments.append(segment)
48
+
49
+ return segments
50
+
51
+
52
+ def parse_conll_files(data_paths):
53
+ """
54
+ Parse CoNLL formatted files and return list of segments for each file and index
55
+ the vocabs and tags across all data_paths
56
+ :param data_paths: tuple(Path) - tuple of filenames
57
+ :return: tuple( [[(token, tag), ...], [(token, tag), ...]], -> segments for data_paths[i]
58
+ [[(token, tag), ...], [(token, tag), ...]], -> segments for data_paths[i+1],
59
+ ...
60
+ )
61
+ List of segments for each dataset and each segment has list of (tokens, tags)
62
+ """
63
+ vocabs = namedtuple("Vocab", ["tags", "tokens"])
64
+ datasets, tags, tokens = list(), list(), list()
65
+
66
+ for data_path in data_paths:
67
+ dataset = conll_to_segments(data_path)
68
+ datasets.append(dataset)
69
+ tokens += [token.text for segment in dataset for token in segment]
70
+ tags += [token.gold_tag for segment in dataset for token in segment]
71
+
72
+ # Flatten list of tags
73
+ tags = list(itertools.chain(*tags))
74
+
75
+ # Generate vocabs for tags and tokens
76
+ tag_vocabs = tag_vocab_by_type(tags)
77
+ tag_vocabs.insert(0, Vocab(Counter(tags)))
78
+ vocabs = vocabs(tokens=Vocab(Counter(tokens), specials=["UNK"]), tags=tag_vocabs)
79
+ return tuple(datasets), vocabs
80
+
81
+
82
+ def tag_vocab_by_type(tags):
83
+ vocabs = list()
84
+ c = Counter(tags)
85
+ tag_names = c.keys()
86
+ tag_types = sorted(list(set([tag.split("-", 1)[1] for tag in tag_names if "-" in tag])))
87
+
88
+ for tag_type in tag_types:
89
+ r = re.compile(".*-" + tag_type + "$")
90
+ t = list(filter(r.match, tags)) + ["O"]
91
+ vocabs.append(Vocab(Counter(t)))
92
+
93
+ return vocabs
94
+
95
 
 def text2segments(text):
     """
 
     """
     dataset = [[Token(text=token, gold_tag=["O"]) for token in text.split()]]
     tokens = [token.text for segment in dataset for token in segment]
+
     # Generate vocabs for the tokens
     segment_vocab = Vocab(Counter(tokens), specials=["UNK"])
+    return dataset, segment_vocab
+
+
+def get_dataloaders(
+    datasets, vocab, data_config, batch_size=32, num_workers=0, shuffle=(True, False, False)
+):
+    """
+    Generate a DataLoader for each dataset
+    :param datasets: list - list of datasets, each a list of segments
+    :param vocab: vocab(s) passed through to the dataset class
+    :param data_config: dict - {"fn": dataset class path, "kwargs": dataset keyword arguments}
+    :param batch_size: int
+    :param num_workers: int
+    :param shuffle: tuple(bool) - whether to shuffle each dataset, one flag per dataset
+    :return: List[torch.utils.data.DataLoader]
+    """
+    dataloaders = list()
+
+    for i, examples in enumerate(datasets):
+        data_config["kwargs"].update({"examples": examples, "vocab": vocab})
+        dataset = load_object(data_config["fn"], data_config["kwargs"])
+
+        dataloader = DataLoader(
+            dataset=dataset,
+            shuffle=shuffle[i],
+            batch_size=batch_size,
+            num_workers=num_workers,
+            collate_fn=dataset.collate_fn,
+        )
+
+        logger.info("%s batches found", len(dataloader))
+        dataloaders.append(dataloader)
+
+    return dataloaders
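To show how data_config is consumed, a hedged sketch of a typical call; the dataset class path, its kwargs, and the file names are placeholders rather than values confirmed by this repository, and whether the dataset class expects the full vocab namedtuple or only part of it depends on that class:

    from Nested.utils.data import parse_conll_files, get_dataloaders

    datasets, vocabs = parse_conll_files(("train.txt", "val.txt", "test.txt"))
    data_config = {
        "fn": "Nested.data.datasets.SomeDataset",  # placeholder class path
        "kwargs": {},  # "examples" and "vocab" are injected by get_dataloaders above
    }
    train_dl, val_dl, test_dl = get_dataloaders(datasets, vocabs, data_config, batch_size=16)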
Nested/utils/helpers.py ADDED
@@ -0,0 +1,117 @@
+import os
+import sys
+import logging
+import importlib
+import shutil
+import torch
+import pickle
+import json
+import random
+import numpy as np
+from argparse import Namespace
+
+
+def logging_config(log_file=None):
+    """
+    Initialize custom logger
+    :param log_file: str - path to log file, full path
+    :return: None
+    """
+    handlers = [logging.StreamHandler(sys.stdout)]
+
+    if log_file:
+        handlers.append(logging.FileHandler(log_file, "w", "utf-8"))
+        print("Logging to {}".format(log_file))
+
+    logging.basicConfig(
+        level=logging.INFO,
+        handlers=handlers,
+        format="%(levelname)s\t%(name)s\t%(asctime)s\t%(message)s",
+        datefmt="%a, %d %b %Y %H:%M:%S",
+        force=True
+    )
+
+
+def load_object(name, kwargs):
+    """
+    Load objects dynamically given the object name and its arguments
+    :param name: str - object name, class name or function name
+    :param kwargs: dict - keyword arguments
+    :return: object
+    """
+    object_module, object_name = name.rsplit(".", 1)
+    object_module = importlib.import_module(object_module)
+    fn = getattr(object_module, object_name)(**kwargs)
+    return fn
+
+
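A minimal sketch of what load_object does in practice; torch.optim.SGD is used purely as an illustrative target:

    import torch
    from Nested.utils.helpers import load_object

    layer = torch.nn.Linear(4, 2)
    optimizer = load_object("torch.optim.SGD", {"params": layer.parameters(), "lr": 0.01})
    # Equivalent to: torch.optim.SGD(params=layer.parameters(), lr=0.01)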
+def make_output_dirs(path, subdirs=[], overwrite=True):
+    """
+    Create root directory and any other sub-directories
+    :param path: str - root directory
+    :param subdirs: List[str] - list of sub-directories
+    :param overwrite: boolean - whether to remove and recreate an existing root directory
+    :return: None
+    """
+    if overwrite:
+        shutil.rmtree(path, ignore_errors=True)
+
+    os.makedirs(path)
+
+    for subdir in subdirs:
+        os.makedirs(os.path.join(path, subdir))
+
+
+def load_checkpoint(model_path):
+    """
+    Load the model given the model path
+    :param model_path: str - path to the model directory
+    :return: tagger - Nested.trainers.BaseTrainer - the tagger model
+             vocab - Nested.utils.data.Vocab - indexed tags
+             train_config - argparse.Namespace - training configurations
+    """
+    with open(os.path.join(model_path, "tag_vocab.pkl"), "rb") as fh:
+        tag_vocab = pickle.load(fh)
+
+    # Load train configurations from checkpoint
+    train_config = Namespace()
+    with open(os.path.join(model_path, "args.json"), "r") as fh:
+        train_config.__dict__ = json.load(fh)
+
+    # Initialize the loss function; not used for inference, only for evaluation
+    loss = load_object(train_config.loss["fn"], train_config.loss["kwargs"])
+
+    # Load BERT tagger
+    model = load_object(train_config.network_config["fn"], train_config.network_config["kwargs"])
+    model = torch.nn.DataParallel(model)
+
+    if torch.cuda.is_available():
+        model = model.cuda()
+
+    # Update arguments for the tagger
+    # Attach the model and the loss (the loss is used for evaluation)
+    train_config.trainer_config["kwargs"]["model"] = model
+    train_config.trainer_config["kwargs"]["loss"] = loss
+
+    tagger = load_object(train_config.trainer_config["fn"], train_config.trainer_config["kwargs"])
+    tagger.load(os.path.join(model_path, "checkpoints"))
+    return tagger, tag_vocab, train_config
+
+
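For orientation, a hedged sketch of how the function is typically called; the directory path is a placeholder, and what the returned tagger exposes depends on the trainer class recorded in args.json:

    from Nested.utils.helpers import load_checkpoint

    tagger, tag_vocab, train_config = load_checkpoint("output/my_model")  # placeholder path
    # The directory is expected to contain tag_vocab.pkl, args.json and a checkpoints/ folder.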
+def set_seed(seed):
+    """
+    Set the seed for random initialization and set the
+    cuDNN parameters to ensure deterministic results across
+    multiple runs with the same seed
+
+    :param seed: int
+    """
+    np.random.seed(seed)
+    random.seed(seed)
+    torch.manual_seed(seed)
+    torch.cuda.manual_seed(seed)
+    torch.cuda.manual_seed_all(seed)
+
+    torch.backends.cudnn.deterministic = True
+    torch.backends.cudnn.benchmark = False
+    torch.backends.cudnn.enabled = False
Nested/utils/metrics.py ADDED
@@ -0,0 +1,69 @@
+from seqeval.metrics import (
+    classification_report,
+    precision_score,
+    recall_score,
+    f1_score,
+    accuracy_score,
+)
+from seqeval.scheme import IOB2
+from types import SimpleNamespace
+import logging
+import re
+
+logger = logging.getLogger(__name__)
+
+
+def compute_nested_metrics(segments, vocabs):
+    """
+    Compute metrics for nested NER
+    :param segments: List[List[Nested.data.datasets.Token]] - list of segments
+    :param vocabs: list of tag vocabs, one per entity type
+    :return: metrics - SimpleNamespace - F1 micro/macro/weighted, recall, precision, accuracy
+    """
+    y, y_hat = list(), list()
+
+    # We duplicate the dataset N times, where N is the number of entity types.
+    # For each copy, we create y and y_hat: the first copy pairs ground truth and
+    # predicted labels for one entity type (e.g. GPE), the next copy for LOC, etc.
+    for i, vocab in enumerate(vocabs):
+        vocab_tags = [tag for tag in vocab.get_itos() if "-" in tag]
+        r = re.compile("|".join(vocab_tags))
+
+        y += [[(list(filter(r.match, token.gold_tag)) or ["O"])[0] for token in segment] for segment in segments]
+        y_hat += [[token.pred_tag[i]["tag"] for token in segment] for segment in segments]
+
+    logger.info("\n" + classification_report(y, y_hat, scheme=IOB2, digits=4))
+
+    metrics = {
+        "micro_f1": f1_score(y, y_hat, average="micro", scheme=IOB2),
+        "macro_f1": f1_score(y, y_hat, average="macro", scheme=IOB2),
+        "weights_f1": f1_score(y, y_hat, average="weighted", scheme=IOB2),
+        "precision": precision_score(y, y_hat, scheme=IOB2),
+        "recall": recall_score(y, y_hat, scheme=IOB2),
+        "accuracy": accuracy_score(y, y_hat),
+    }
+
+    return SimpleNamespace(**metrics)
+
+
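To make the per-type duplication concrete, a small self-contained sketch using plain lists in place of Token objects; the tags and the LOC/PERS head ordering are invented for illustration:

    from seqeval.metrics import f1_score
    from seqeval.scheme import IOB2

    # Pass 1 (say, the LOC head) and pass 2 (say, the PERS head) are simply concatenated:
    y     = [["B-LOC", "O"], ["O", "B-PERS"]]
    y_hat = [["B-LOC", "O"], ["O", "O"]]
    print(f1_score(y, y_hat, average="micro", scheme=IOB2))  # ~0.67: one of the two gold entities is found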
+def compute_single_label_metrics(segments):
+    """
+    Compute metrics for flat NER
+    :param segments: List[List[Nested.data.datasets.Token]] - list of segments
+    :return: metrics - SimpleNamespace - F1 micro/macro/weighted, recall, precision, accuracy
+    """
+    y = [[token.gold_tag[0] for token in segment] for segment in segments]
+    y_hat = [[token.pred_tag[0]["tag"] for token in segment] for segment in segments]
+
+    logger.info("\n" + classification_report(y, y_hat, scheme=IOB2, digits=4))
+
+    metrics = {
+        "micro_f1": f1_score(y, y_hat, average="micro", scheme=IOB2),
+        "macro_f1": f1_score(y, y_hat, average="macro", scheme=IOB2),
+        "weights_f1": f1_score(y, y_hat, average="weighted", scheme=IOB2),
+        "precision": precision_score(y, y_hat, scheme=IOB2),
+        "recall": recall_score(y, y_hat, scheme=IOB2),
+        "accuracy": accuracy_score(y, y_hat),
+    }
+
+    return SimpleNamespace(**metrics)