Spaces: Running
Upload 37 files
Browse files
- Nested/__init__.py +0 -0
- Nested/__pycache__/__init__.cpython-311.pyc +0 -0
- Nested/bin/__init__.py +0 -0
- Nested/bin/eval.py +87 -0
- Nested/bin/infer.py +73 -0
- Nested/bin/process.py +140 -0
- Nested/bin/train.py +222 -0
- Nested/data/__init__.py +0 -0
- Nested/data/__pycache__/__init__.cpython-311.pyc +0 -0
- Nested/data/__pycache__/datasets.cpython-311.pyc +0 -0
- Nested/data/__pycache__/transforms.cpython-311.pyc +0 -0
- Nested/data/datasets.py +150 -0
- Nested/data/transforms.py +127 -0
- Nested/nn/BaseModel.py +22 -0
- Nested/nn/BertNestedTagger.py +34 -0
- Nested/nn/BertSeqTagger.py +4 -1
- Nested/nn/__init__.py +3 -0
- Nested/nn/__pycache__/BaseModel.cpython-311.pyc +0 -0
- Nested/nn/__pycache__/BertNestedTagger.cpython-311.pyc +0 -0
- Nested/nn/__pycache__/BertSeqTagger.cpython-311.pyc +0 -0
- Nested/nn/__pycache__/__init__.cpython-311.pyc +0 -0
- Nested/trainers/BaseTrainer.py +117 -0
- Nested/trainers/BertNestedTrainer.py +203 -0
- Nested/trainers/BertTrainer.py +163 -0
- Nested/trainers/__init__.py +3 -0
- Nested/trainers/__pycache__/BaseTrainer.cpython-311.pyc +0 -0
- Nested/trainers/__pycache__/BertNestedTrainer.cpython-311.pyc +0 -0
- Nested/trainers/__pycache__/BertTrainer.cpython-311.pyc +0 -0
- Nested/trainers/__pycache__/__init__.cpython-311.pyc +0 -0
- Nested/utils/__init__.py +0 -0
- Nested/utils/__pycache__/__init__.cpython-311.pyc +0 -0
- Nested/utils/__pycache__/data.cpython-311.pyc +0 -0
- Nested/utils/__pycache__/helpers.cpython-311.pyc +0 -0
- Nested/utils/__pycache__/metrics.cpython-311.pyc +0 -0
- Nested/utils/data.py +112 -38
- Nested/utils/helpers.py +117 -0
- Nested/utils/metrics.py +69 -0
Nested/__init__.py
ADDED
File without changes

Nested/__pycache__/__init__.cpython-311.pyc
ADDED
Binary file (149 Bytes)
Nested/bin/__init__.py
ADDED
File without changes
Nested/bin/eval.py
ADDED
@@ -0,0 +1,87 @@
import os
import logging
import argparse
from collections import namedtuple
from Nested.utils.helpers import load_checkpoint, make_output_dirs, logging_config
from Nested.utils.data import get_dataloaders, parse_conll_files
from Nested.utils.metrics import compute_single_label_metrics, compute_nested_metrics

logger = logging.getLogger(__name__)


def parse_args():
    parser = argparse.ArgumentParser(
        formatter_class=argparse.ArgumentDefaultsHelpFormatter,
    )

    parser.add_argument(
        "--output_path",
        type=str,
        required=True,
        help="Path to save results",
    )

    parser.add_argument(
        "--model_path",
        type=str,
        required=True,
        help="Model path",
    )

    parser.add_argument(
        "--data_paths",
        nargs="+",
        type=str,
        required=True,
        help="Data files to tag, in the same format as the training data, with the 'O' tag for all tokens",
    )

    parser.add_argument(
        "--batch_size",
        type=int,
        default=32,
        help="Batch size",
    )

    args = parser.parse_args()

    return args


def main(args):
    # Create directory to save predictions
    make_output_dirs(args.output_path, overwrite=True)
    logging_config(log_file=os.path.join(args.output_path, "eval.log"))

    # Load tagger
    tagger, tag_vocab, train_config = load_checkpoint(args.model_path)

    # Parse the data files into tagger datasets and index their tokens
    datasets, vocab = parse_conll_files(args.data_paths)

    vocabs = namedtuple("Vocab", ["tags", "tokens"])
    vocab = vocabs(tokens=vocab.tokens, tags=tag_vocab)

    # From the datasets generate the dataloaders
    dataloaders = get_dataloaders(
        datasets, vocab,
        train_config.data_config,
        batch_size=args.batch_size,
        shuffle=[False] * len(datasets)
    )

    # Evaluate the model on each dataloader
    for dataloader, input_file in zip(dataloaders, args.data_paths):
        filename = os.path.basename(input_file)
        predictions_file = os.path.join(args.output_path, f"predictions_{filename}")
        _, segments, _, _ = tagger.eval(dataloader)
        tagger.segments_to_file(segments, predictions_file)

        if "Nested" in train_config.trainer_config["fn"]:
            compute_nested_metrics(segments, vocab.tags[1:])
        else:
            compute_single_label_metrics(segments)


if __name__ == "__main__":
    main(parse_args())
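A minimal sketch of driving this entry point programmatically rather than from the shell, assuming the Nested package is importable; every path below is a hypothetical placeholder, not a file from this commit:

from argparse import Namespace
from Nested.bin.eval import main

# Hypothetical paths; data_paths must point to CoNLL files with an 'O' tag per token
main(Namespace(
    output_path="runs/eval",
    model_path="runs/train",
    data_paths=["data/test.txt"],
    batch_size=32,
))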
Nested/bin/infer.py
ADDED
@@ -0,0 +1,73 @@
import logging
import argparse
from collections import namedtuple
from Nested.utils.helpers import load_checkpoint
from Nested.utils.data import get_dataloaders, text2segments

logger = logging.getLogger(__name__)


def parse_args():
    parser = argparse.ArgumentParser(
        formatter_class=argparse.ArgumentDefaultsHelpFormatter,
    )

    parser.add_argument(
        "--model_path",
        type=str,
        required=True,
        help="Model path",
    )

    parser.add_argument(
        "--text",
        type=str,
        required=True,
        help="Text or sequence to tag",
    )

    parser.add_argument(
        "--batch_size",
        type=int,
        default=32,
        help="Batch size",
    )

    args = parser.parse_args()

    return args


def main(args):
    # Load tagger
    tagger, tag_vocab, train_config = load_checkpoint(args.model_path)

    # Convert the text to a tagger dataset and index the tokens in args.text
    dataset, token_vocab = text2segments(args.text)

    vocabs = namedtuple("Vocab", ["tags", "tokens"])
    vocab = vocabs(tokens=token_vocab, tags=tag_vocab)

    # From the dataset generate the dataloader
    dataloader = get_dataloaders(
        (dataset,),
        vocab,
        train_config.data_config,
        batch_size=args.batch_size,
        shuffle=(False,),
    )[0]

    # Perform inference on the text and get back the tagged segments
    segments = tagger.infer(dataloader)

    # Print results
    for segment in segments:
        s = [
            f"{token.text} ({'|'.join([t['tag'] for t in token.pred_tag])})"
            for token in segment
        ]
        print(" ".join(s))


if __name__ == "__main__":
    main(parse_args())
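The same pattern works for inference, again as a sketch with placeholder values; exactly how text2segments tokenizes the raw string is an assumption here, since Nested/utils/data.py is part of this commit but its body is not shown above:

from argparse import Namespace
from Nested.bin.infer import main

main(Namespace(
    model_path="runs/train",        # hypothetical trained-model directory
    text="some sequence to tag",    # placeholder input text
    batch_size=32,
))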
Nested/bin/process.py
ADDED
@@ -0,0 +1,140 @@
import os
import argparse
import csv
import logging
import numpy as np
from Nested.utils.helpers import logging_config
from Nested.utils.data import conll_to_segments

logger = logging.getLogger(__name__)


def to_conll_format(input_files, output_path, multi_label=False):
    """
    Parse data files and convert them into CoNLL format
    :param input_files: List[str] - list of filenames
    :param output_path: str - output path
    :param multi_label: boolean - True to process multi-class/multi-label data
    :return:
    """
    for input_file in input_files:
        tokens = list()
        prev_sent_id = None

        with open(input_file, "r") as fh:
            r = csv.reader(fh, delimiter="\t", quotechar=" ")
            next(r)

            for row in r:
                sent_id, token, labels = row[1], row[3], row[4].split()
                valid_labels = sum([1 for l in labels if "-" in l or l == "O"]) == len(labels)

                if not valid_labels:
                    logging.warning("Invalid labels found %s", str(row))
                    continue
                if not labels:
                    logging.warning("Token %s has no label", str(row))
                    continue
                if not token:
                    logging.warning("Token %s is missing", str(row))
                    continue
                if len(token.split()) > 1:
                    logging.warning("Token %s has multiple tokens", str(row))
                    continue

                if prev_sent_id is not None and sent_id != prev_sent_id:
                    tokens.append([])

                if multi_label:
                    tokens.append([token] + labels)
                else:
                    tokens.append([token, labels[0]])

                prev_sent_id = sent_id

        num_segments = sum([1 for token in tokens if not token])
        logging.info("Found %d segments and %d tokens in %s", num_segments + 1, len(tokens) - num_segments, input_file)

        filename = os.path.basename(input_file)
        output_file = os.path.join(output_path, filename)

        with open(output_file, "w") as fh:
            fh.write("\n".join(" ".join(token) for token in tokens))
        logging.info("Output file %s", output_file)


def train_dev_test_split(input_files, output_path, train_ratio, dev_ratio):
    segments = list()
    filenames = ["train.txt", "val.txt", "test.txt"]

    for input_file in input_files:
        segments += conll_to_segments(input_file)

    n = len(segments)
    np.random.shuffle(segments)
    datasets = np.split(segments, [int(train_ratio * n), int((train_ratio + dev_ratio) * n)])

    # Write data to files
    for i in range(len(datasets)):
        filename = os.path.join(output_path, filenames[i])

        with open(filename, "w") as fh:
            text = "\n\n".join(["\n".join([f"{token.text} {' '.join(token.gold_tag)}" for token in segment]) for segment in datasets[i]])
            fh.write(text)
        logging.info("Output file %s", filename)


def main(args):
    if args.task == "to_conll_format":
        to_conll_format(args.input_files, args.output_path, multi_label=args.multi_label)
    if args.task == "train_dev_test_split":
        train_dev_test_split(args.input_files, args.output_path, args.train_ratio, args.dev_ratio)


if __name__ == "__main__":
    parser = argparse.ArgumentParser(
        formatter_class=argparse.ArgumentDefaultsHelpFormatter,
    )

    parser.add_argument(
        "--input_files",
        type=str,
        nargs="+",
        required=True,
        help="List of input files",
    )

    parser.add_argument(
        "--output_path",
        type=str,
        required=True,
        help="Output path",
    )

    parser.add_argument(
        "--train_ratio",
        type=float,
        required=False,
        help="Training data ratio (percent of segments). Required with the task train_dev_test_split. "
             "Files must be in CoNLL format",
    )

    parser.add_argument(
        "--dev_ratio",
        type=float,
        required=False,
        help="Dev/val data ratio (percent of segments). Required with the task train_dev_test_split. "
             "Files must be in CoNLL format",
    )

    parser.add_argument(
        "--task", required=True, choices=["to_conll_format", "train_dev_test_split"]
    )

    parser.add_argument(
        "--multi_label", action="store_true"
    )

    args = parser.parse_args()
    logging_config(os.path.join(args.output_path, "process.log"))
    main(args)
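The split points handed to np.split in train_dev_test_split are cumulative fractions of the shuffled segment list; a self-contained illustration of that arithmetic with made-up ratios:

import numpy as np

segments = list(range(10))              # stand-ins for parsed segments
train_ratio, dev_ratio = 0.8, 0.1
n = len(segments)
np.random.shuffle(segments)
# Split at 80% and 90% of n: an 8/1/1 train/val/test partition
train, dev, test = np.split(np.array(segments),
                            [int(train_ratio * n), int((train_ratio + dev_ratio) * n)])
print(len(train), len(dev), len(test))  # 8 1 1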
Nested/bin/train.py
ADDED
@@ -0,0 +1,222 @@
import os
import logging
import json
import argparse
import pickle
import torch
import torch.utils.tensorboard
from Nested.utils.data import get_dataloaders, parse_conll_files
from Nested.utils.helpers import logging_config, load_object, make_output_dirs, set_seed

logger = logging.getLogger(__name__)


def parse_args():
    parser = argparse.ArgumentParser(
        formatter_class=argparse.ArgumentDefaultsHelpFormatter,
    )

    parser.add_argument(
        "--output_path",
        type=str,
        required=True,
        help="Output path",
    )

    parser.add_argument(
        "--train_path",
        type=str,
        required=True,
        help="Path to training data",
    )

    parser.add_argument(
        "--val_path",
        type=str,
        required=True,
        help="Path to validation data",
    )

    parser.add_argument(
        "--test_path",
        type=str,
        required=True,
        help="Path to test data",
    )

    parser.add_argument(
        "--bert_model",
        type=str,
        default="aubmindlab/bert-base-arabertv2",
        help="BERT model",
    )

    parser.add_argument(
        "--gpus",
        type=int,
        nargs="+",
        default=[0],
        help="GPU IDs to train on",
    )

    parser.add_argument(
        "--log_interval",
        type=int,
        default=10,
        help="Log results every that many timesteps",
    )

    parser.add_argument(
        "--batch_size",
        type=int,
        default=32,
        help="Batch size",
    )

    parser.add_argument(
        "--num_workers",
        type=int,
        default=0,
        help="Dataloader number of workers",
    )

    parser.add_argument(
        "--data_config",
        type=json.loads,
        default='{"fn": "Nested.data.datasets.DefaultDataset", "kwargs": {"max_seq_len": 512}}',
        help="Dataset configurations",
    )

    parser.add_argument(
        "--trainer_config",
        type=json.loads,
        default='{"fn": "Nested.trainers.BertTrainer", "kwargs": {"max_epochs": 50}}',
        help="Trainer configurations",
    )

    parser.add_argument(
        "--network_config",
        type=json.loads,
        default='{"fn": "Nested.nn.BertSeqTagger", "kwargs": '
                '{"dropout": 0.1, "bert_model": "aubmindlab/bert-base-arabertv2"}}',
        help="Network configurations",
    )

    parser.add_argument(
        "--optimizer",
        type=json.loads,
        default='{"fn": "torch.optim.AdamW", "kwargs": {"lr": 0.0001}}',
        help="Optimizer configurations",
    )

    parser.add_argument(
        "--lr_scheduler",
        type=json.loads,
        default='{"fn": "torch.optim.lr_scheduler.ExponentialLR", "kwargs": {"gamma": 1}}',
        help="Learning rate scheduler configurations",
    )

    parser.add_argument(
        "--loss",
        type=json.loads,
        default='{"fn": "torch.nn.CrossEntropyLoss", "kwargs": {}}',
        help="Loss function configurations",
    )

    parser.add_argument(
        "--overwrite",
        action="store_true",
        help="Overwrite output directory",
    )

    parser.add_argument(
        "--seed",
        type=int,
        default=1,
        help="Seed for random initialization",
    )

    args = parser.parse_args()

    return args


def main(args):
    make_output_dirs(
        args.output_path,
        subdirs=("tensorboard", "checkpoints"),
        overwrite=args.overwrite,
    )

    # Set the seed for randomization
    set_seed(args.seed)

    logging_config(os.path.join(args.output_path, "train.log"))
    summary_writer = torch.utils.tensorboard.SummaryWriter(
        os.path.join(args.output_path, "tensorboard")
    )
    os.environ["CUDA_VISIBLE_DEVICES"] = ",".join([str(gpu) for gpu in args.gpus])

    # Get the datasets and vocab for tags and tokens
    datasets, vocab = parse_conll_files((args.train_path, args.val_path, args.test_path))

    if "Nested" in args.network_config["fn"]:
        args.network_config["kwargs"]["num_labels"] = [len(v) for v in vocab.tags[1:]]
    else:
        args.network_config["kwargs"]["num_labels"] = len(vocab.tags[0])

    args.data_config["kwargs"]["bert_model"] = args.network_config["kwargs"]["bert_model"]

    # Save tag vocab to disk
    with open(os.path.join(args.output_path, "tag_vocab.pkl"), "wb") as fh:
        pickle.dump(vocab.tags, fh)

    # Write config to file
    args_file = os.path.join(args.output_path, "args.json")
    with open(args_file, "w") as fh:
        logger.info("Writing config to %s", args_file)
        json.dump(args.__dict__, fh, indent=4)

    # From the datasets generate the dataloaders
    train_dataloader, val_dataloader, test_dataloader = get_dataloaders(
        datasets, vocab, args.data_config, args.batch_size, args.num_workers
    )

    model = load_object(args.network_config["fn"], args.network_config["kwargs"])
    model = torch.nn.DataParallel(model, device_ids=range(len(args.gpus)))

    if torch.cuda.is_available():
        model = model.cuda()

    args.optimizer["kwargs"]["params"] = model.parameters()
    optimizer = load_object(args.optimizer["fn"], args.optimizer["kwargs"])

    args.lr_scheduler["kwargs"]["optimizer"] = optimizer
    if "num_training_steps" in args.lr_scheduler["kwargs"]:
        args.lr_scheduler["kwargs"]["num_training_steps"] = args.trainer_config["kwargs"]["max_epochs"] * len(
            train_dataloader
        )

    scheduler = load_object(args.lr_scheduler["fn"], args.lr_scheduler["kwargs"])
    loss = load_object(args.loss["fn"], args.loss["kwargs"])

    args.trainer_config["kwargs"].update({
        "model": model,
        "optimizer": optimizer,
        "scheduler": scheduler,
        "loss": loss,
        "train_dataloader": train_dataloader,
        "val_dataloader": val_dataloader,
        "test_dataloader": test_dataloader,
        "log_interval": args.log_interval,
        "summary_writer": summary_writer,
        "output_path": args.output_path
    })

    trainer = load_object(args.trainer_config["fn"], args.trainer_config["kwargs"])
    trainer.train()
    return


if __name__ == "__main__":
    main(parse_args())
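All of the JSON flags above share the {"fn": "<dotted path>", "kwargs": {...}} shape that load_object consumes. Nested/utils/helpers.py is part of this commit but its body is not shown here, so the following resolver is only a sketch of how such a loader typically works, not the actual implementation:

import importlib

def load_object_sketch(fn, kwargs):
    # Resolve a dotted path such as "torch.optim.AdamW" and instantiate it
    module_path, _, name = fn.rpartition(".")
    obj = getattr(importlib.import_module(module_path), name)
    return obj(**kwargs)

# e.g. load_object_sketch("torch.nn.CrossEntropyLoss", {}) returns a loss module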
Nested/data/__init__.py
ADDED
File without changes

Nested/data/__pycache__/__init__.cpython-311.pyc
ADDED
Binary file (154 Bytes)

Nested/data/__pycache__/datasets.cpython-311.pyc
ADDED
Binary file (7.31 kB)

Nested/data/__pycache__/transforms.cpython-311.pyc
ADDED
Binary file (9.26 kB)
Nested/data/datasets.py
ADDED
@@ -0,0 +1,150 @@
import logging
import torch
from torch.utils.data import Dataset
from torch.nn.utils.rnn import pad_sequence
from Nested.data.transforms import (
    BertSeqTransform,
    NestedTagsTransform
)

logger = logging.getLogger(__name__)


class Token:
    def __init__(self, text=None, pred_tag=None, gold_tag=None):
        """
        Token object to hold token attributes
        :param text: str
        :param pred_tag: str
        :param gold_tag: str
        """
        self.text = text
        self.gold_tag = gold_tag
        self.pred_tag = pred_tag
        self.subwords = None

    @property
    def subwords(self):
        return self._subwords

    @subwords.setter
    def subwords(self, value):
        self._subwords = value

    def __str__(self):
        """
        Token text representation
        :return: str
        """
        gold_tags = "|".join(self.gold_tag)

        if self.pred_tag:
            pred_tags = "|".join([pred_tag["tag"] for pred_tag in self.pred_tag])
        else:
            pred_tags = ""

        if self.gold_tag:
            r = f"{self.text}\t{gold_tags}\t{pred_tags}"
        else:
            r = f"{self.text}\t{pred_tags}"

        return r


class DefaultDataset(Dataset):
    def __init__(
        self,
        examples=None,
        vocab=None,
        bert_model="aubmindlab/bert-base-arabertv2",
        max_seq_len=512,
    ):
        """
        The dataset used to transform the segments into training data
        :param examples: list[[tuple]] - [[(token, tag), (token, tag), ...], [(token, tag), ...]]
                         You can generate examples with Nested.utils.data.parse_conll_files
        :param vocab: vocab object containing indexed tags and tokens
        :param bert_model: str - BERT model
        :param max_seq_len: int - maximum sequence length
        """
        self.transform = BertSeqTransform(bert_model, vocab, max_seq_len=max_seq_len)
        self.examples = examples
        self.vocab = vocab

    def __len__(self):
        return len(self.examples)

    def __getitem__(self, item):
        subwords, tags, tokens, valid_len = self.transform(self.examples[item])
        return subwords, tags, tokens, valid_len

    def collate_fn(self, batch):
        """
        Collate function that is called when the batch is fetched by the trainer
        :param batch: Dataloader batch
        :return: Same output as the __getitem__ function
        """
        subwords, tags, tokens, valid_len = zip(*batch)

        # Pad sequences in this batch
        # subwords and tokens are padded with zeros
        # tags are padded with the index of the O tag
        subwords = pad_sequence(subwords, batch_first=True, padding_value=0)
        tags = pad_sequence(
            tags, batch_first=True, padding_value=self.vocab.tags[0].get_stoi()["O"]
        )
        return subwords, tags, tokens, valid_len


class NestedTagsDataset(Dataset):
    def __init__(
        self,
        examples=None,
        vocab=None,
        bert_model="aubmindlab/bert-base-arabertv2",
        max_seq_len=512,
    ):
        """
        The dataset used to transform the segments into training data
        :param examples: list[[tuple]] - [[(token, tag), (token, tag), ...], [(token, tag), ...]]
                         You can generate examples with Nested.utils.data.parse_conll_files
        :param vocab: vocab object containing indexed tags and tokens
        :param bert_model: str - BERT model
        :param max_seq_len: int - maximum sequence length
        """
        self.transform = NestedTagsTransform(
            bert_model, vocab, max_seq_len=max_seq_len
        )
        self.examples = examples
        self.vocab = vocab

    def __len__(self):
        return len(self.examples)

    def __getitem__(self, item):
        subwords, tags, tokens, masks, valid_len = self.transform(self.examples[item])
        return subwords, tags, tokens, masks, valid_len

    def collate_fn(self, batch):
        """
        Collate function that is called when the batch is fetched by the trainer
        :param batch: Dataloader batch
        :return: Same output as the __getitem__ function
        """
        subwords, tags, tokens, masks, valid_len = zip(*batch)

        # Pad sequences in this batch
        # subwords and tokens are padded with zeros
        # tags are padded with the index of the O tag
        subwords = pad_sequence(subwords, batch_first=True, padding_value=0)

        masks = [torch.nn.ConstantPad1d((0, subwords.shape[-1] - tag.shape[-1]), 0)(mask)
                 for tag, mask in zip(tags, masks)]
        masks = torch.cat(masks)

        # Pad the tags, doing the padding for each tag type
        tags = [torch.nn.ConstantPad1d((0, subwords.shape[-1] - tag.shape[-1]), vocab.get_stoi()["O"])(tag)
                for tag, vocab in zip(tags, self.vocab.tags[1:])]
        tags = torch.cat(tags)

        return subwords, tags, tokens, masks, valid_len
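A self-contained illustration of the padding convention used by the collate functions above: subwords pad with 0, while tags pad with the index of the "O" tag (assumed to be 2 in this demo):

import torch
from torch.nn.utils.rnn import pad_sequence

o_index = 2  # hypothetical index of "O" in the tag vocab
tags = [torch.LongTensor([0, 1, 2]), torch.LongTensor([1])]
print(pad_sequence(tags, batch_first=True, padding_value=o_index))
# tensor([[0, 1, 2],
#         [1, 2, 2]])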
Nested/data/transforms.py
ADDED
@@ -0,0 +1,127 @@
import torch
from transformers import BertTokenizer
from functools import partial
import logging
import re
import itertools
import Nested

logger = logging.getLogger(__name__)


class BertSeqTransform:
    def __init__(self, bert_model, vocab, max_seq_len=512):
        self.tokenizer = BertTokenizer.from_pretrained(bert_model)
        self.encoder = partial(
            self.tokenizer.encode,
            max_length=max_seq_len,
            truncation=True,
        )
        self.max_seq_len = max_seq_len
        self.vocab = vocab

    def __call__(self, segment):
        subwords, tags, tokens = list(), list(), list()
        unk_token = Nested.data.datasets.Token(text="UNK")

        for token in segment:
            # Sometimes the tokenizer fails to encode the word and returns no
            # input_ids; in that case, we use the input_id for [UNK]
            token_subwords = self.encoder(token.text)[1:-1] or self.encoder("[UNK]")[1:-1]
            subwords += token_subwords
            tags += [self.vocab.tags[0].get_stoi()[token.gold_tag[0]]] + [self.vocab.tags[0].get_stoi()["O"]] * (len(token_subwords) - 1)
            tokens += [token] + [unk_token] * (len(token_subwords) - 1)

        # Truncate to max_seq_len
        if len(subwords) > self.max_seq_len - 2:
            text = " ".join([t.text for t in tokens if t.text != "UNK"])
            logger.info("Truncating the sequence %s to %d", text, self.max_seq_len - 2)
            subwords = subwords[:self.max_seq_len - 2]
            tags = tags[:self.max_seq_len - 2]
            tokens = tokens[:self.max_seq_len - 2]

        subwords.insert(0, self.tokenizer.cls_token_id)
        subwords.append(self.tokenizer.sep_token_id)

        tags.insert(0, self.vocab.tags[0].get_stoi()["O"])
        tags.append(self.vocab.tags[0].get_stoi()["O"])

        tokens.insert(0, unk_token)
        tokens.append(unk_token)

        return torch.LongTensor(subwords), torch.LongTensor(tags), tokens, len(tokens)


class NestedTagsTransform:
    def __init__(self, bert_model, vocab, max_seq_len=512):
        self.tokenizer = BertTokenizer.from_pretrained(bert_model)
        self.encoder = partial(
            self.tokenizer.encode,
            max_length=max_seq_len,
            truncation=True,
        )
        self.max_seq_len = max_seq_len
        self.vocab = vocab

    def __call__(self, segment):
        tags, tokens, subwords = list(), list(), list()
        unk_token = Nested.data.datasets.Token(text="UNK")

        # Encode each token and get its subwords and IDs
        for token in segment:
            # Sometimes the tokenizer fails to encode the word and returns no
            # input_ids; in that case, we use the input_id for [UNK]
            token.subwords = self.encoder(token.text)[1:-1] or self.encoder("[UNK]")[1:-1]
            subwords += token.subwords
            tokens += [token] + [unk_token] * (len(token.subwords) - 1)

        # Construct the labels for each tag type
        # The sequence will have a list of tags for each type
        # The final tags for a sequence is a matrix NUM_TAG_TYPES x SEQ_LEN
        # Example:
        # [
        #   [O, O, B-PERS, I-PERS, O, O, O]
        #   [B-ORG, I-ORG, O, O, O, O, O]
        #   [O, O, O, O, O, O, B-GPE]
        # ]
        for vocab in self.vocab.tags[1:]:
            vocab_tags = "|".join(["^" + t + "$" for t in vocab.get_itos() if "-" in t])
            r = re.compile(vocab_tags)

            # This is really messy
            # For a given token we find a matching tag_name, BUT we might find
            # multiple matches (i.e. a token can be labeled B-ORG and I-ORG); in this
            # case we take only the first tag, as we do not have overlaps of the same type
            single_type_tags = [[(list(filter(r.match, token.gold_tag))
                                  or ["O"])[0]] + ["O"] * (len(token.subwords) - 1)
                                for token in segment]
            single_type_tags = list(itertools.chain(*single_type_tags))
            tags.append([vocab.get_stoi()[tag] for tag in single_type_tags])

        # Truncate to max_seq_len
        if len(subwords) > self.max_seq_len - 2:
            text = " ".join([t.text for t in tokens if t.text != "UNK"])
            logger.info("Truncating the sequence %s to %d", text, self.max_seq_len - 2)
            subwords = subwords[:self.max_seq_len - 2]
            tags = [t[:self.max_seq_len - 2] for t in tags]
            tokens = tokens[:self.max_seq_len - 2]

        # Add a dummy token at the start and end of the sequence
        tokens.insert(0, unk_token)
        tokens.append(unk_token)

        # Add CLS and SEP at the start and end of the subwords
        subwords.insert(0, self.tokenizer.cls_token_id)
        subwords.append(self.tokenizer.sep_token_id)
        subwords = torch.LongTensor(subwords)

        # Add "O" tags for the first and last subwords
        tags = torch.Tensor(tags)
        tags = torch.column_stack((
            torch.Tensor([vocab.get_stoi()["O"] for vocab in self.vocab.tags[1:]]),
            tags,
            torch.Tensor([vocab.get_stoi()["O"] for vocab in self.vocab.tags[1:]]),
        )).unsqueeze(0)

        mask = torch.ones_like(tags)
        return subwords, tags, tokens, mask, len(tokens)
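The per-type label construction above hinges on one alternation regex per tag-type vocab; here is that matching step in isolation, with hypothetical tags:

import re

type_vocab = ["O", "B-PERS", "I-PERS"]  # hypothetical single-type vocab
r = re.compile("|".join("^" + t + "$" for t in type_vocab if "-" in t))

# Keep the first gold tag that belongs to this type, else fall back to "O"
gold_tag = ["B-ORG", "I-PERS"]
print((list(filter(r.match, gold_tag)) or ["O"])[0])  # I-PERS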
Nested/nn/BaseModel.py
ADDED
@@ -0,0 +1,22 @@
from torch import nn
from transformers import BertModel
import logging

logger = logging.getLogger(__name__)


class BaseModel(nn.Module):
    def __init__(self,
                 bert_model="aubmindlab/bert-base-arabertv2",
                 num_labels=2,
                 dropout=0.1,
                 num_types=0):
        super().__init__()

        self.bert_model = bert_model
        self.num_labels = num_labels
        self.num_types = num_types
        self.dropout = dropout

        self.bert = BertModel.from_pretrained(bert_model)
        self.dropout = nn.Dropout(dropout)
Nested/nn/BertNestedTagger.py
ADDED
@@ -0,0 +1,34 @@
import torch
import torch.nn as nn
from Nested.nn import BaseModel


class BertNestedTagger(BaseModel):
    def __init__(self, **kwargs):
        super(BertNestedTagger, self).__init__(**kwargs)

        self.max_num_labels = max(self.num_labels)
        classifiers = [nn.Linear(768, num_labels) for num_labels in self.num_labels]
        self.classifiers = torch.nn.Sequential(*classifiers)

    def forward(self, x):
        y = self.bert(x)
        y = self.dropout(y["last_hidden_state"])
        output = list()

        for i, classifier in enumerate(self.classifiers):
            logits = classifier(y)

            # Pad logits to allow multi-GPU/DataParallel training to work
            # We will truncate the padded dimensions when we compute the loss in the trainer
            logits = torch.nn.ConstantPad1d((0, self.max_num_labels - logits.shape[-1]), 0)(logits)
            output.append(logits)

        # Return a tensor of the shape B x T x L x C
        # B: batch size
        # T: sequence length
        # L: number of tag types
        # C: number of classes per tag type
        output = torch.stack(output).permute((1, 2, 0, 3))
        return output
Nested/nn/BertSeqTagger.py
CHANGED
@@ -1,14 +1,17 @@
 import torch.nn as nn
 from transformers import BertModel
 
+
 class BertSeqTagger(nn.Module):
     def __init__(self, bert_model, num_labels=2, dropout=0.1):
         super().__init__()
+
         self.bert = BertModel.from_pretrained(bert_model)
         self.dropout = nn.Dropout(dropout)
         self.linear = nn.Linear(768, num_labels)
+
     def forward(self, x):
         y = self.bert(x)
         y = self.dropout(y["last_hidden_state"])
         logits = self.linear(y)
-        return logits
+        return logits
Nested/nn/__init__.py
ADDED
@@ -0,0 +1,3 @@
from Nested.nn.BaseModel import BaseModel
from Nested.nn.BertSeqTagger import BertSeqTagger
from Nested.nn.BertNestedTagger import BertNestedTagger
Nested/nn/__pycache__/BaseModel.cpython-311.pyc
ADDED
Binary file (1.34 kB)

Nested/nn/__pycache__/BertNestedTagger.cpython-311.pyc
ADDED
Binary file (2.33 kB)

Nested/nn/__pycache__/BertSeqTagger.cpython-311.pyc
ADDED
Binary file (1.54 kB)

Nested/nn/__pycache__/__init__.cpython-311.pyc
ADDED
Binary file (379 Bytes)
Nested/trainers/BaseTrainer.py
ADDED
@@ -0,0 +1,117 @@
import os
import torch
import logging
import natsort
import glob

logger = logging.getLogger(__name__)


class BaseTrainer:
    def __init__(
        self,
        model=None,
        max_epochs=50,
        optimizer=None,
        scheduler=None,
        loss=None,
        train_dataloader=None,
        val_dataloader=None,
        test_dataloader=None,
        log_interval=10,
        summary_writer=None,
        output_path=None,
        clip=5,
        patience=5
    ):
        self.model = model
        self.max_epochs = max_epochs
        self.train_dataloader = train_dataloader
        self.val_dataloader = val_dataloader
        self.test_dataloader = test_dataloader
        self.optimizer = optimizer
        self.scheduler = scheduler
        self.loss = loss
        self.log_interval = log_interval
        self.summary_writer = summary_writer
        self.output_path = output_path
        self.current_timestep = 0
        self.current_epoch = 0
        self.clip = clip
        self.patience = patience

    def tag(self, dataloader, is_train=True):
        """
        Given a dataloader containing segments, predict the tags
        :param dataloader: torch.utils.data.DataLoader
        :param is_train: boolean - True for training mode, False for evaluation
        :return: Iterator
                 subwords (B x T x NUM_LABELS) - torch.Tensor - BERT subword IDs
                 gold_tags (B x T x NUM_LABELS) - torch.Tensor - ground truth tag IDs
                 tokens - List[Nested.data.datasets.Token] - list of tokens
                 valid_len (B x 1) - int - valid length of each sequence
                 logits (B x T x NUM_LABELS) - logits for each token and each tag
        """
        for subwords, gold_tags, tokens, valid_len in dataloader:
            self.model.train(is_train)

            if torch.cuda.is_available():
                subwords = subwords.cuda()
                gold_tags = gold_tags.cuda()

            if is_train:
                self.optimizer.zero_grad()
                logits = self.model(subwords)
            else:
                with torch.no_grad():
                    logits = self.model(subwords)

            yield subwords, gold_tags, tokens, valid_len, logits

    def segments_to_file(self, segments, filename):
        """
        Write segments to file
        :param segments: List[List[Nested.data.datasets.Token]] - list of lists of tokens
        :param filename: str - output filename
        :return: None
        """
        with open(filename, "w") as fh:
            results = "\n\n".join(["\n".join([str(t) for t in segment]) for segment in segments])
            fh.write("Token\tGold Tag\tPredicted Tag\n")
            fh.write(results)
        logging.info("Predictions written to %s", filename)

    def save(self):
        """
        Save model checkpoint
        :return: None
        """
        filename = os.path.join(
            self.output_path,
            "checkpoints",
            "checkpoint_{}.pt".format(self.current_epoch),
        )

        checkpoint = {
            "model": self.model.state_dict(),
            "optimizer": self.optimizer.state_dict(),
            "epoch": self.current_epoch
        }

        logger.info("Saving checkpoint to %s", filename)
        torch.save(checkpoint, filename)

    def load(self, checkpoint_path):
        """
        Load the latest model checkpoint
        :param checkpoint_path: str - path/to/checkpoints
        :return: None
        """
        checkpoint_path = natsort.natsorted(glob.glob(f"{checkpoint_path}/checkpoint_*.pt"))
        checkpoint_path = checkpoint_path[-1]

        logger.info("Loading checkpoint %s", checkpoint_path)

        device = None if torch.cuda.is_available() else torch.device('cpu')
        checkpoint = torch.load(checkpoint_path, map_location=device, weights_only=False)
        self.model.load_state_dict(checkpoint["model"])
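BaseTrainer.load leans on natural sorting to pick the newest checkpoint; a quick demonstration of why plain lexicographic sorting would pick the wrong file:

import natsort

paths = ["checkpoint_2.pt", "checkpoint_10.pt", "checkpoint_1.pt"]
print(sorted(paths))             # ['checkpoint_1.pt', 'checkpoint_10.pt', 'checkpoint_2.pt']
print(natsort.natsorted(paths))  # ['checkpoint_1.pt', 'checkpoint_2.pt', 'checkpoint_10.pt']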
Nested/trainers/BertNestedTrainer.py
ADDED
@@ -0,0 +1,203 @@
import os
import logging
import torch
import numpy as np
from Nested.trainers import BaseTrainer
from Nested.utils.metrics import compute_nested_metrics

logger = logging.getLogger(__name__)


class BertNestedTrainer(BaseTrainer):
    def __init__(self, **kwargs):
        super().__init__(**kwargs)

    def train(self):
        best_val_loss, test_loss = np.inf, np.inf
        num_train_batch = len(self.train_dataloader)
        num_labels = [len(v) for v in self.train_dataloader.dataset.vocab.tags[1:]]
        patience = self.patience

        for epoch_index in range(self.max_epochs):
            self.current_epoch = epoch_index
            train_loss = 0

            for batch_index, (subwords, gold_tags, tokens, valid_len, logits) in enumerate(self.tag(
                self.train_dataloader, is_train=True
            ), 1):
                self.current_timestep += 1

                # Compute losses for each output
                # logits = B x T x L x C
                losses = [self.loss(logits[:, :, i, 0:l].view(-1, logits[:, :, i, 0:l].shape[-1]),
                                    torch.reshape(gold_tags[:, i, :], (-1,)).long())
                          for i, l in enumerate(num_labels)]

                torch.autograd.backward(losses)

                # Avoid exploding gradients by doing gradient clipping
                torch.nn.utils.clip_grad_norm_(self.model.parameters(), self.clip)

                self.optimizer.step()
                self.scheduler.step()
                batch_loss = sum(l.item() for l in losses)
                train_loss += batch_loss

                if self.current_timestep % self.log_interval == 0:
                    logger.info(
                        "Epoch %d | Batch %d/%d | Timestep %d | LR %.10f | Loss %f",
                        epoch_index,
                        batch_index,
                        num_train_batch,
                        self.current_timestep,
                        self.optimizer.param_groups[0]['lr'],
                        batch_loss
                    )

            train_loss /= num_train_batch

            logger.info("** Evaluating on validation dataset **")
            val_preds, segments, valid_len, val_loss = self.eval(self.val_dataloader)
            val_metrics = compute_nested_metrics(segments, self.val_dataloader.dataset.transform.vocab.tags[1:])

            epoch_summary_loss = {
                "train_loss": train_loss,
                "val_loss": val_loss
            }
            epoch_summary_metrics = {
                "val_micro_f1": val_metrics.micro_f1,
                "val_precision": val_metrics.precision,
                "val_recall": val_metrics.recall
            }

            logger.info(
                "Epoch %d | Timestep %d | Train Loss %f | Val Loss %f | F1 %f",
                epoch_index,
                self.current_timestep,
                train_loss,
                val_loss,
                val_metrics.micro_f1
            )

            if val_loss < best_val_loss:
                patience = self.patience
                best_val_loss = val_loss
                logger.info("** Validation improved, evaluating test data **")
                test_preds, segments, valid_len, test_loss = self.eval(self.test_dataloader)
                self.segments_to_file(segments, os.path.join(self.output_path, "predictions.txt"))
                test_metrics = compute_nested_metrics(segments, self.test_dataloader.dataset.transform.vocab.tags[1:])

                epoch_summary_loss["test_loss"] = test_loss
                epoch_summary_metrics["test_micro_f1"] = test_metrics.micro_f1
                epoch_summary_metrics["test_precision"] = test_metrics.precision
                epoch_summary_metrics["test_recall"] = test_metrics.recall

                logger.info(
                    "Epoch %d | Timestep %d | Test Loss %f | F1 %f",
                    epoch_index,
                    self.current_timestep,
                    test_loss,
                    test_metrics.micro_f1
                )

                self.save()
            else:
                patience -= 1

            # No improvement, terminate early
            if patience == 0:
                logger.info("Early termination triggered")
                break

            self.summary_writer.add_scalars("Loss", epoch_summary_loss, global_step=self.current_timestep)
            self.summary_writer.add_scalars("Metrics", epoch_summary_metrics, global_step=self.current_timestep)

    def tag(self, dataloader, is_train=True):
        """
        Given a dataloader containing segments, predict the tags
        :param dataloader: torch.utils.data.DataLoader
        :param is_train: boolean - True for training mode, False for evaluation
        :return: Iterator
                 subwords (B x T x NUM_LABELS) - torch.Tensor - BERT subword IDs
                 gold_tags (B x T x NUM_LABELS) - torch.Tensor - ground truth tag IDs
                 tokens - List[Nested.data.datasets.Token] - list of tokens
                 valid_len (B x 1) - int - valid length of each sequence
                 logits (B x T x NUM_LABELS) - logits for each token and each tag
        """
        for subwords, gold_tags, tokens, mask, valid_len in dataloader:
            self.model.train(is_train)

            if torch.cuda.is_available():
                subwords = subwords.cuda()
                gold_tags = gold_tags.cuda()

            if is_train:
                self.optimizer.zero_grad()
                logits = self.model(subwords)
            else:
                with torch.no_grad():
                    logits = self.model(subwords)

            yield subwords, gold_tags, tokens, valid_len, logits

    def eval(self, dataloader):
        golds, preds, segments, valid_lens = list(), list(), list(), list()
        num_labels = [len(v) for v in dataloader.dataset.vocab.tags[1:]]
        loss = 0

        for _, gold_tags, tokens, valid_len, logits in self.tag(
            dataloader, is_train=False
        ):
            losses = [self.loss(logits[:, :, i, 0:l].view(-1, logits[:, :, i, 0:l].shape[-1]),
                                torch.reshape(gold_tags[:, i, :], (-1,)).long())
                      for i, l in enumerate(num_labels)]
            loss += sum(losses)
            preds += torch.argmax(logits, dim=3)
            segments += tokens
            valid_lens += list(valid_len)

        loss /= len(dataloader)

        # Update segments, attaching the predicted tags to each token
        segments = self.to_segments(segments, preds, valid_lens, dataloader.dataset.vocab)

        return preds, segments, valid_lens, loss

    def infer(self, dataloader):
        golds, preds, segments, valid_lens = list(), list(), list(), list()

        for _, gold_tags, tokens, valid_len, logits in self.tag(
            dataloader, is_train=False
        ):
            preds += torch.argmax(logits, dim=3)
            segments += tokens
            valid_lens += list(valid_len)

        segments = self.to_segments(segments, preds, valid_lens, dataloader.dataset.vocab)
        return segments

    def to_segments(self, segments, preds, valid_lens, vocab):
        if vocab is None:
            vocab = self.vocab

        tagged_segments = list()
        tokens_stoi = vocab.tokens.get_stoi()
        unk_id = tokens_stoi["UNK"]

        for segment, pred, valid_len in zip(segments, preds, valid_lens):
            # First, drop the token at the 0th index ([CLS]) and the last valid
            # token ([SEP]), then combine the tokens with their corresponding predictions
            segment_pred = zip(segment[1:valid_len-1], pred[1:valid_len-1])

            # Ignore the sub-tokens/subwords, which are identified by their text being UNK
            segment_pred = list(filter(lambda t: tokens_stoi[t[0].text] != unk_id, segment_pred))

            # Attach the predicted tags to each token
            list(map(lambda t: setattr(t[0], 'pred_tag', [{"tag": vocab.get_itos()[tag_id]}
                     for tag_id, vocab in zip(t[1].int().tolist(), vocab.tags[1:])]), segment_pred))

            # We are only interested in the tagged tokens; we no longer need the raw model predictions
            tagged_segment = [t for t, _ in segment_pred]
            tagged_segments.append(tagged_segment)

        return tagged_segments
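The per-type loss above slices the padded class dimension back down to each type's real label count before cross-entropy; a dummy-tensor sketch of that slicing (all shapes and counts here are made up):

import torch
import torch.nn.functional as F

B, T, L, C = 2, 5, 3, 7            # batch, seq len, tag types, padded class dim
num_labels = [7, 4, 3]             # real class count per tag type
logits = torch.randn(B, T, L, C, requires_grad=True)
gold = torch.randint(0, 3, (B, L, T))

# reshape (rather than view) handles the non-contiguous slice in this sketch
losses = [
    F.cross_entropy(logits[:, :, i, :l].reshape(-1, l),
                    gold[:, i, :].reshape(-1))
    for i, l in enumerate(num_labels)
]
torch.autograd.backward(losses)    # one backward pass over all tag types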
Nested/trainers/BertTrainer.py
ADDED
@@ -0,0 +1,163 @@
import os
import logging
import torch
import numpy as np
from Nested.trainers import BaseTrainer
from Nested.utils.metrics import compute_single_label_metrics

logger = logging.getLogger(__name__)


class BertTrainer(BaseTrainer):
    def __init__(self, **kwargs):
        super().__init__(**kwargs)

    def train(self):
        best_val_loss, test_loss = np.inf, np.inf
        num_train_batch = len(self.train_dataloader)
        patience = self.patience

        for epoch_index in range(self.max_epochs):
            self.current_epoch = epoch_index
            train_loss = 0

            for batch_index, (_, gold_tags, _, _, logits) in enumerate(self.tag(
                self.train_dataloader, is_train=True
            ), 1):
                self.current_timestep += 1
                batch_loss = self.loss(logits.view(-1, logits.shape[-1]), gold_tags.view(-1))
                batch_loss.backward()

                # Avoid exploding gradients by clipping the gradient norm
                torch.nn.utils.clip_grad_norm_(self.model.parameters(), self.clip)

                self.optimizer.step()
                self.scheduler.step()
                train_loss += batch_loss.item()

                if self.current_timestep % self.log_interval == 0:
                    logger.info(
                        "Epoch %d | Batch %d/%d | Timestep %d | LR %.10f | Loss %f",
                        epoch_index,
                        batch_index,
                        num_train_batch,
                        self.current_timestep,
                        self.optimizer.param_groups[0]['lr'],
                        batch_loss.item()
                    )

            train_loss /= num_train_batch

            logger.info("** Evaluating on validation dataset **")
            val_preds, segments, valid_len, val_loss = self.eval(self.val_dataloader)
            val_metrics = compute_single_label_metrics(segments)

            epoch_summary_loss = {
                "train_loss": train_loss,
                "val_loss": val_loss
            }
            epoch_summary_metrics = {
                "val_micro_f1": val_metrics.micro_f1,
                "val_precision": val_metrics.precision,
                "val_recall": val_metrics.recall
            }

            logger.info(
                "Epoch %d | Timestep %d | Train Loss %f | Val Loss %f | F1 %f",
                epoch_index,
                self.current_timestep,
                train_loss,
                val_loss,
                val_metrics.micro_f1
            )

            if val_loss < best_val_loss:
                patience = self.patience
                best_val_loss = val_loss
                logger.info("** Validation improved, evaluating test data **")
                test_preds, segments, valid_len, test_loss = self.eval(self.test_dataloader)
                self.segments_to_file(segments, os.path.join(self.output_path, "predictions.txt"))
                test_metrics = compute_single_label_metrics(segments)

                epoch_summary_loss["test_loss"] = test_loss
                epoch_summary_metrics["test_micro_f1"] = test_metrics.micro_f1
                epoch_summary_metrics["test_precision"] = test_metrics.precision
                epoch_summary_metrics["test_recall"] = test_metrics.recall

                logger.info(
                    "Epoch %d | Timestep %d | Test Loss %f | F1 %f",
                    epoch_index,
                    self.current_timestep,
                    test_loss,
                    test_metrics.micro_f1
                )

                self.save()
            else:
                patience -= 1

            # No improvement, terminate early
            if patience == 0:
                logger.info("Early termination triggered")
                break

            self.summary_writer.add_scalars("Loss", epoch_summary_loss, global_step=self.current_timestep)
            self.summary_writer.add_scalars("Metrics", epoch_summary_metrics, global_step=self.current_timestep)

    def eval(self, dataloader):
        golds, preds, segments, valid_lens = list(), list(), list(), list()
        loss = 0

        for _, gold_tags, tokens, valid_len, logits in self.tag(
            dataloader, is_train=False
        ):
            loss += self.loss(logits.view(-1, logits.shape[-1]), gold_tags.view(-1))
            preds += torch.argmax(logits, dim=2).detach().cpu().numpy().tolist()
            segments += tokens
            valid_lens += list(valid_len)

        loss /= len(dataloader)

        # Update segments, attach predicted tags to each token
        segments = self.to_segments(segments, preds, valid_lens, dataloader.dataset.vocab)

        return preds, segments, valid_lens, loss.item()

    def infer(self, dataloader):
        golds, preds, segments, valid_lens = list(), list(), list(), list()

        for _, gold_tags, tokens, valid_len, logits in self.tag(
            dataloader, is_train=False
        ):
            preds += torch.argmax(logits, dim=2).detach().cpu().numpy().tolist()
            segments += tokens
            valid_lens += list(valid_len)

        segments = self.to_segments(segments, preds, valid_lens, dataloader.dataset.vocab)
        return segments

    def to_segments(self, segments, preds, valid_lens, vocab):
        if vocab is None:
            vocab = self.vocab

        tagged_segments = list()
        tokens_stoi = vocab.tokens.get_stoi()
        tags_itos = vocab.tags[0].get_itos()
        unk_id = tokens_stoi["UNK"]

        for segment, pred, valid_len in zip(segments, preds, valid_lens):
            # Drop the token at the 0th index ([CLS]) and the token at the nth index ([SEP]),
            # then combine the tokens with their corresponding predictions
            segment_pred = zip(segment[1:valid_len-1], pred[1:valid_len-1])

            # Ignore the sub-tokens/subwords, which are identified by their text being UNK
            segment_pred = list(filter(lambda t: tokens_stoi[t[0].text] != unk_id, segment_pred))

            # Attach the predicted tags to each token
            list(map(lambda t: setattr(t[0], 'pred_tag', [{"tag": tags_itos[t[1]]}]), segment_pred))

            # We are only interested in the tagged tokens; we no longer need the raw model predictions
            tagged_segment = [t for t, _ in segment_pred]
            tagged_segments.append(tagged_segment)

        return tagged_segments
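A note on the loss calls in `train` and `eval` above: the `(batch, sequence length, num tags)` logits are flattened to `(batch * sequence length, num tags)` so a token-level loss can score every position at once. A minimal sketch with invented sizes, assuming the dynamically loaded loss is `torch.nn.CrossEntropyLoss`:

import torch

loss_fn = torch.nn.CrossEntropyLoss()

# Hypothetical sizes: 2 segments, 5 tokens, 7 tags in the flat tag vocab.
logits = torch.randn(2, 5, 7)
gold_tags = torch.randint(0, 7, (2, 5))

# Mirrors self.loss(logits.view(-1, logits.shape[-1]), gold_tags.view(-1))
batch_loss = loss_fn(logits.view(-1, logits.shape[-1]), gold_tags.view(-1))
print(batch_loss.item())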
Nested/trainers/__init__.py
ADDED
@@ -0,0 +1,3 @@
from Nested.trainers.BaseTrainer import BaseTrainer
from Nested.trainers.BertTrainer import BertTrainer
from Nested.trainers.BertNestedTrainer import BertNestedTrainer
Nested/trainers/__pycache__/BaseTrainer.cpython-311.pyc
ADDED
Binary file (6.45 kB). View file

Nested/trainers/__pycache__/BertNestedTrainer.cpython-311.pyc
ADDED
Binary file (13.4 kB). View file

Nested/trainers/__pycache__/BertTrainer.cpython-311.pyc
ADDED
Binary file (9.43 kB). View file

Nested/trainers/__pycache__/__init__.cpython-311.pyc
ADDED
Binary file (405 Bytes). View file

Nested/utils/__init__.py
ADDED
File without changes

Nested/utils/__pycache__/__init__.cpython-311.pyc
ADDED
Binary file (155 Bytes). View file

Nested/utils/__pycache__/data.cpython-311.pyc
ADDED
Binary file (8.66 kB). View file

Nested/utils/__pycache__/helpers.cpython-311.pyc
ADDED
Binary file (5.9 kB). View file

Nested/utils/__pycache__/metrics.cpython-311.pyc
ADDED
Binary file (5.45 kB). View file
Nested/utils/data.py
CHANGED
@@ -1,7 +1,16 @@
from torch.utils.data import DataLoader
from collections import Counter, namedtuple
import logging
import re
import itertools
from Nested.utils.helpers import load_object
from Nested.data.datasets import Token

logger = logging.getLogger(__name__)


class Vocab:
    def __init__(self, counter, specials=[]) -> None:
        self.itos = list(counter.keys()) + specials
        self.stoi = {s: i for i, s in enumerate(self.itos)}
        self.word_count = counter
@@ -12,44 +21,77 @@ class Vocab:
    def get_stoi(self) -> dict[str, int]:
        return self.stoi

    def __len__(self):
        return len(self.itos)


def conll_to_segments(filename):
    """
    Convert CoNLL files to segments. This returns a list of segments, where each segment
    is a list of tuples (token, tag)
    :param filename: Path
    :return: list[[tuple]] - [[(token, tag), (token, tag), ...], [(token, tag), ...]]
    """
    segments, segment = list(), list()

    with open(filename, "r") as fh:
        for token in fh.read().splitlines():
            if not token.strip():
                segments.append(segment)
                segment = list()
            else:
                parts = token.split()
                token = Token(text=parts[0], gold_tag=parts[1:])
                segment.append(token)

        segments.append(segment)

    return segments


def parse_conll_files(data_paths):
    """
    Parse CoNLL-formatted files, return the list of segments for each file, and index
    the vocabs and tags across all data_paths
    :param data_paths: tuple(Path) - tuple of filenames
    :return: tuple( [[(token, tag), ...], [(token, tag), ...]], -> segments for data_paths[i]
                    [[(token, tag), ...], [(token, tag), ...]], -> segments for data_paths[i+1],
                    ...
                  )
             List of segments for each dataset, where each segment has a list of (tokens, tags)
    """
    vocabs = namedtuple("Vocab", ["tags", "tokens"])
    datasets, tags, tokens = list(), list(), list()

    for data_path in data_paths:
        dataset = conll_to_segments(data_path)
        datasets.append(dataset)
        tokens += [token.text for segment in dataset for token in segment]
        tags += [token.gold_tag for segment in dataset for token in segment]

    # Flatten the list of tags
    tags = list(itertools.chain(*tags))

    # Generate vocabs for tags and tokens
    tag_vocabs = tag_vocab_by_type(tags)
    tag_vocabs.insert(0, Vocab(Counter(tags)))
    vocabs = vocabs(tokens=Vocab(Counter(tokens), specials=["UNK"]), tags=tag_vocabs)
    return tuple(datasets), vocabs


def tag_vocab_by_type(tags):
    vocabs = list()
    c = Counter(tags)
    tag_names = c.keys()
    tag_types = sorted(list(set([tag.split("-", 1)[1] for tag in tag_names if "-" in tag])))

    for tag_type in tag_types:
        r = re.compile(".*-" + tag_type + "$")
        t = list(filter(r.match, tags)) + ["O"]
        vocabs.append(Vocab(Counter(t)))

    return vocabs


def text2segments(text):
    """
@@ -57,6 +99,38 @@ def text2segments(text):
    """
    dataset = [[Token(text=token, gold_tag=["O"]) for token in text.split()]]
    tokens = [token.text for segment in dataset for token in segment]

    # Generate vocabs for the tokens
    segment_vocab = Vocab(Counter(tokens), specials=["UNK"])
    return dataset, segment_vocab


def get_dataloaders(
    datasets, vocab, data_config, batch_size=32, num_workers=0, shuffle=(True, False, False)
):
    """
    From the datasets, generate the dataloaders
    :param datasets: list - list of the datasets, each a list of segments and tokens
    :param batch_size: int
    :param num_workers: int
    :param shuffle: tuple(bool) - whether to shuffle each dataset
    :return: List[torch.utils.data.DataLoader]
    """
    dataloaders = list()

    for i, examples in enumerate(datasets):
        data_config["kwargs"].update({"examples": examples, "vocab": vocab})
        dataset = load_object(data_config["fn"], data_config["kwargs"])

        dataloader = DataLoader(
            dataset=dataset,
            shuffle=shuffle[i],
            batch_size=batch_size,
            num_workers=num_workers,
            collate_fn=dataset.collate_fn,
        )

        logger.info("%s batches found", len(dataloader))
        dataloaders.append(dataloader)

    return dataloaders
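To make the CoNLL conventions above concrete, a hedged usage sketch; the file content is invented, and it relies only on `Token` exposing `.text` and `.gold_tag` as in `Nested/data/datasets.py`:

from Nested.utils.data import conll_to_segments, parse_conll_files

# Hypothetical two-segment CoNLL file: one "token tag..." pair per line,
# with blank lines separating segments.
with open("/tmp/sample.conll", "w") as fh:
    fh.write("Paris B-LOC\nis O\nnice O\n\nJohn B-PER\n")

segments = conll_to_segments("/tmp/sample.conll")
print(len(segments))            # 2
print(segments[0][0].text)      # Paris
print(segments[0][0].gold_tag)  # ['B-LOC']

# parse_conll_files additionally builds token and tag vocabs across files:
# vocabs.tags holds one flat tag vocab plus one vocab per entity type.
(dataset,), vocabs = parse_conll_files(["/tmp/sample.conll"])
print(len(vocabs.tags))  # 3: flat vocab + LOC vocab + PER vocab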
Nested/utils/helpers.py
ADDED
@@ -0,0 +1,117 @@
import os
import sys
import logging
import importlib
import shutil
import torch
import pickle
import json
import random
import numpy as np
from argparse import Namespace


def logging_config(log_file=None):
    """
    Initialize custom logger
    :param log_file: str - full path to the log file
    :return: None
    """
    handlers = [logging.StreamHandler(sys.stdout)]

    if log_file:
        handlers.append(logging.FileHandler(log_file, "w", "utf-8"))
        print("Logging to {}".format(log_file))

    logging.basicConfig(
        level=logging.INFO,
        handlers=handlers,
        format="%(levelname)s\t%(name)s\t%(asctime)s\t%(message)s",
        datefmt="%a, %d %b %Y %H:%M:%S",
        force=True
    )


def load_object(name, kwargs):
    """
    Load objects dynamically given the object name and its arguments
    :param name: str - object name, class name or function name
    :param kwargs: dict - keyword arguments
    :return: object
    """
    object_module, object_name = name.rsplit(".", 1)
    object_module = importlib.import_module(object_module)
    fn = getattr(object_module, object_name)(**kwargs)
    return fn


def make_output_dirs(path, subdirs=[], overwrite=True):
    """
    Create the root directory and any other sub-directories
    :param path: str - root directory
    :param subdirs: List[str] - list of sub-directories
    :param overwrite: boolean - whether to overwrite the directory
    :return: None
    """
    if overwrite:
        shutil.rmtree(path, ignore_errors=True)

    os.makedirs(path)

    for subdir in subdirs:
        os.makedirs(os.path.join(path, subdir))


def load_checkpoint(model_path):
    """
    Load a model given the model path
    :param model_path: str - path to model
    :return: tagger - Nested.trainers.BaseTrainer - the tagger model
             vocab - Nested.utils.data.Vocab - indexed tags
             train_config - argparse.Namespace - training configurations
    """
    with open(os.path.join(model_path, "tag_vocab.pkl"), "rb") as fh:
        tag_vocab = pickle.load(fh)

    # Load train configurations from the checkpoint
    train_config = Namespace()
    with open(os.path.join(model_path, "args.json"), "r") as fh:
        train_config.__dict__ = json.load(fh)

    # Initialize the loss function; not used for inference, but needed for evaluation
    loss = load_object(train_config.loss["fn"], train_config.loss["kwargs"])

    # Load the BERT tagger
    model = load_object(train_config.network_config["fn"], train_config.network_config["kwargs"])
    model = torch.nn.DataParallel(model)

    if torch.cuda.is_available():
        model = model.cuda()

    # Update arguments for the tagger:
    # attach the model and the loss (the loss is used in evaluation cases)
    train_config.trainer_config["kwargs"]["model"] = model
    train_config.trainer_config["kwargs"]["loss"] = loss

    tagger = load_object(train_config.trainer_config["fn"], train_config.trainer_config["kwargs"])
    tagger.load(os.path.join(model_path, "checkpoints"))
    return tagger, tag_vocab, train_config


def set_seed(seed):
    """
    Set the seed for random initialization and set
    cuDNN parameters to ensure deterministic results across
    multiple runs with the same seed

    :param seed: int
    """
    np.random.seed(seed)
    random.seed(seed)
    torch.manual_seed(seed)
    torch.cuda.manual_seed(seed)
    torch.cuda.manual_seed_all(seed)

    torch.backends.cudnn.deterministic = True
    torch.backends.cudnn.benchmark = False
    torch.backends.cudnn.enabled = False
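A hedged usage sketch of `load_object` above: a dotted path plus a kwargs dict instantiates any importable callable. `collections.namedtuple` is a stand-in target here; in the repo, the same mechanism loads the network, loss, and trainer classes named in the training config:

from Nested.utils.helpers import load_object

# namedtuple is purely illustrative; any "module.Callable" path works.
Point = load_object(
    "collections.namedtuple",
    {"typename": "Point", "field_names": ["x", "y"]},
)
print(Point(x=1, y=2))  # Point(x=1, y=2)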
Nested/utils/metrics.py
ADDED
@@ -0,0 +1,69 @@
from seqeval.metrics import (
    classification_report,
    precision_score,
    recall_score,
    f1_score,
    accuracy_score,
)
from seqeval.scheme import IOB2
from types import SimpleNamespace
import logging
import re

logger = logging.getLogger(__name__)


def compute_nested_metrics(segments, vocabs):
    """
    Compute metrics for nested NER
    :param segments: List[List[Nested.data.datasets.Token]] - list of segments
    :param vocabs: List[Nested.utils.data.Vocab] - one tag vocab per entity type
    :return: metrics - SimpleNamespace - F1 micro/macro/weighted, recall, precision, accuracy
    """
    y, y_hat = list(), list()

    # We duplicate the dataset N times, where N is the number of entity types.
    # For each copy, we create y and y_hat.
    # Example: the first copy creates pairs of ground-truth and predicted labels for
    # the entity type GPE, another copy creates pairs for LOC, etc.
    for i, vocab in enumerate(vocabs):
        vocab_tags = [tag for tag in vocab.get_itos() if "-" in tag]
        r = re.compile("|".join(vocab_tags))

        y += [[(list(filter(r.match, token.gold_tag)) or ["O"])[0] for token in segment] for segment in segments]
        y_hat += [[token.pred_tag[i]["tag"] for token in segment] for segment in segments]

    logger.info("\n" + classification_report(y, y_hat, scheme=IOB2, digits=4))

    metrics = {
        "micro_f1": f1_score(y, y_hat, average="micro", scheme=IOB2),
        "macro_f1": f1_score(y, y_hat, average="macro", scheme=IOB2),
        "weights_f1": f1_score(y, y_hat, average="weighted", scheme=IOB2),
        "precision": precision_score(y, y_hat, scheme=IOB2),
        "recall": recall_score(y, y_hat, scheme=IOB2),
        "accuracy": accuracy_score(y, y_hat),
    }

    return SimpleNamespace(**metrics)


def compute_single_label_metrics(segments):
    """
    Compute metrics for flat NER
    :param segments: List[List[Nested.data.datasets.Token]] - list of segments
    :return: metrics - SimpleNamespace - F1 micro/macro/weighted, recall, precision, accuracy
    """
    y = [[token.gold_tag[0] for token in segment] for segment in segments]
    y_hat = [[token.pred_tag[0]["tag"] for token in segment] for segment in segments]

    logger.info("\n" + classification_report(y, y_hat, scheme=IOB2, digits=4))

    metrics = {
        "micro_f1": f1_score(y, y_hat, average="micro", scheme=IOB2),
        "macro_f1": f1_score(y, y_hat, average="macro", scheme=IOB2),
        "weights_f1": f1_score(y, y_hat, average="weighted", scheme=IOB2),
        "precision": precision_score(y, y_hat, scheme=IOB2),
        "recall": recall_score(y, y_hat, scheme=IOB2),
        "accuracy": accuracy_score(y, y_hat),
    }

    return SimpleNamespace(**metrics)
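To illustrate the token interface the metrics expect, a hedged sketch that calls `compute_single_label_metrics` with hand-built stand-ins; `SimpleNamespace` fills in for `Nested.data.datasets.Token`, since only `gold_tag` and `pred_tag` are read here:

from types import SimpleNamespace
from Nested.utils.metrics import compute_single_label_metrics

def tok(gold, pred):
    # Stand-in for Nested.data.datasets.Token; only the fields the
    # metrics code reads are populated.
    return SimpleNamespace(gold_tag=[gold], pred_tag=[{"tag": pred}])

segments = [
    [tok("B-PER", "B-PER"), tok("I-PER", "I-PER"), tok("O", "O")],
    [tok("B-LOC", "O"), tok("O", "O")],
]

metrics = compute_single_label_metrics(segments)
# One of the two gold entities is recovered: precision 1.0, recall 0.5,
# so micro F1 is about 0.67.
print(metrics.micro_f1)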