Upload 61 files
#8
by
shayekh
- opened
This view is limited to 50 files because it contains too many changes.
See the raw diff here.
- src/datasets/__init__.py +7 -0
- src/datasets/__pycache__/__init__.cpython-38.pyc +0 -0
- src/datasets/__pycache__/toxic_spans_crf_3cls_tokens.cpython-38.pyc +0 -0
- src/datasets/__pycache__/toxic_spans_crf_tokens.cpython-38.pyc +0 -0
- src/datasets/__pycache__/toxic_spans_multi_spans.cpython-38.pyc +0 -0
- src/datasets/__pycache__/toxic_spans_spans.cpython-38.pyc +0 -0
- src/datasets/__pycache__/toxic_spans_tokens.cpython-38.pyc +0 -0
- src/datasets/__pycache__/toxic_spans_tokens_3cls.cpython-38.pyc +0 -0
- src/datasets/__pycache__/toxic_spans_tokens_spans.cpython-38.pyc +0 -0
- src/datasets/toxic_spans_crf_3cls_tokens.py +132 -0
- src/datasets/toxic_spans_crf_tokens.py +111 -0
- src/datasets/toxic_spans_multi_spans.py +237 -0
- src/datasets/toxic_spans_spans.py +238 -0
- src/datasets/toxic_spans_tokens.py +81 -0
- src/datasets/toxic_spans_tokens_3cls.py +102 -0
- src/datasets/toxic_spans_tokens_spans.py +269 -0
- src/models/__init__.py +7 -0
- src/models/__pycache__/__init__.cpython-38.pyc +0 -0
- src/models/__pycache__/auto_models.cpython-38.pyc +0 -0
- src/models/__pycache__/bert_crf_token.cpython-38.pyc +0 -0
- src/models/__pycache__/bert_multi_spans.cpython-38.pyc +0 -0
- src/models/__pycache__/bert_token_spans.cpython-38.pyc +0 -0
- src/models/__pycache__/roberta_crf_token.cpython-38.pyc +0 -0
- src/models/__pycache__/roberta_multi_spans.cpython-38.pyc +0 -0
- src/models/__pycache__/roberta_token_spans.cpython-38.pyc +0 -0
- src/models/auto_models.py +6 -0
- src/models/bert_crf_token.py +72 -0
- src/models/bert_multi_spans.py +84 -0
- src/models/bert_token_spans.py +100 -0
- src/models/roberta_crf_token.py +66 -0
- src/models/roberta_multi_spans.py +82 -0
- src/models/roberta_token_spans.py +97 -0
- src/models/two_layer_nn.py +46 -0
- src/modules/__init__.py +0 -0
- src/modules/__pycache__/__init__.cpython-38.pyc +0 -0
- src/modules/__pycache__/embeddings.cpython-38.pyc +0 -0
- src/modules/__pycache__/preprocessors.cpython-38.pyc +0 -0
- src/modules/__pycache__/tokenizers.cpython-38.pyc +0 -0
- src/modules/activations.py +6 -0
- src/modules/embeddings.py +37 -0
- src/modules/losses.py +6 -0
- src/modules/metrics.py +17 -0
- src/modules/optimizers.py +7 -0
- src/modules/preprocessors.py +112 -0
- src/modules/schedulers.py +14 -0
- src/modules/tokenizers.py +107 -0
- src/trainers/__init__.py +0 -0
- src/trainers/base_trainer.py +563 -0
- src/utils/__init__.py +0 -0
- src/utils/__pycache__/__init__.cpython-38.pyc +0 -0
src/datasets/__init__.py
ADDED
|
@@ -0,0 +1,7 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from src.datasets.toxic_spans_tokens import *
|
| 2 |
+
from src.datasets.toxic_spans_tokens_3cls import *
|
| 3 |
+
from src.datasets.toxic_spans_spans import *
|
| 4 |
+
from src.datasets.toxic_spans_tokens_spans import *
|
| 5 |
+
from src.datasets.toxic_spans_multi_spans import *
|
| 6 |
+
from src.datasets.toxic_spans_crf_tokens import *
|
| 7 |
+
from src.datasets.toxic_spans_crf_3cls_tokens import *
|
src/datasets/__pycache__/__init__.cpython-38.pyc
ADDED
|
Binary file (503 Bytes). View file
|
|
|
src/datasets/__pycache__/toxic_spans_crf_3cls_tokens.cpython-38.pyc
ADDED
|
Binary file (2.99 kB). View file
|
|
|
src/datasets/__pycache__/toxic_spans_crf_tokens.cpython-38.pyc
ADDED
|
Binary file (2.76 kB). View file
|
|
|
src/datasets/__pycache__/toxic_spans_multi_spans.cpython-38.pyc
ADDED
|
Binary file (5.55 kB). View file
|
|
|
src/datasets/__pycache__/toxic_spans_spans.cpython-38.pyc
ADDED
|
Binary file (5.2 kB). View file
|
|
|
src/datasets/__pycache__/toxic_spans_tokens.cpython-38.pyc
ADDED
|
Binary file (2.35 kB). View file
|
|
|
src/datasets/__pycache__/toxic_spans_tokens_3cls.cpython-38.pyc
ADDED
|
Binary file (2.59 kB). View file
|
|
|
src/datasets/__pycache__/toxic_spans_tokens_spans.cpython-38.pyc
ADDED
|
Binary file (5.97 kB). View file
|
|
|
src/datasets/toxic_spans_crf_3cls_tokens.py
ADDED
|
@@ -0,0 +1,132 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from src.utils.mapper import configmapper
|
| 2 |
+
from transformers import AutoTokenizer
|
| 3 |
+
from datasets import load_dataset
|
| 4 |
+
import numpy as np
|
| 5 |
+
|
| 6 |
+
|
| 7 |
+
@configmapper.map("datasets", "toxic_spans_crf_3cls_tokens")
|
| 8 |
+
class ToxicSpansCRF3ClsTokenDataset:
|
| 9 |
+
def __init__(self, config):
|
| 10 |
+
self.config = config
|
| 11 |
+
self.tokenizer = AutoTokenizer.from_pretrained(
|
| 12 |
+
self.config.model_checkpoint_name
|
| 13 |
+
)
|
| 14 |
+
self.dataset = load_dataset("csv", data_files=dict(self.config.train_files))
|
| 15 |
+
self.test_dataset = load_dataset("csv", data_files=dict(self.config.eval_files))
|
| 16 |
+
|
| 17 |
+
self.tokenized_inputs = self.dataset.map(
|
| 18 |
+
self.tokenize_and_align_labels_for_train, batched=True
|
| 19 |
+
)
|
| 20 |
+
self.test_tokenized_inputs = self.test_dataset.map(
|
| 21 |
+
self.tokenize_for_test, batched=True
|
| 22 |
+
)
|
| 23 |
+
|
| 24 |
+
def tokenize_and_align_labels_for_train(self, examples):
|
| 25 |
+
tokenized_inputs = self.tokenizer(
|
| 26 |
+
examples["text"], **self.config.tokenizer_params
|
| 27 |
+
)
|
| 28 |
+
|
| 29 |
+
# tokenized_inputs["text"] = examples["text"]
|
| 30 |
+
example_spans = []
|
| 31 |
+
labels = []
|
| 32 |
+
prediction_mask = np.zeros_like(np.array(tokenized_inputs["input_ids"]))
|
| 33 |
+
offsets_mapping = tokenized_inputs["offset_mapping"]
|
| 34 |
+
|
| 35 |
+
## Wrong Code
|
| 36 |
+
# for i, offset_mapping in enumerate(offsets_mapping):
|
| 37 |
+
# j = 0
|
| 38 |
+
# while j < len(offset_mapping): # [tok1, tok2, tok3] [(0,5),(1,4),(5,7)]
|
| 39 |
+
# if tokenized_inputs["input_ids"][i][j] in [
|
| 40 |
+
# self.tokenizer.sep_token_id,
|
| 41 |
+
# self.tokenizer.pad_token_id,
|
| 42 |
+
# self.tokenizer.cls_token_id,
|
| 43 |
+
# ]:
|
| 44 |
+
# j = j + 1
|
| 45 |
+
# continue
|
| 46 |
+
# else:
|
| 47 |
+
# k = j + 1
|
| 48 |
+
# while self.tokenizer.convert_ids_to_tokens(
|
| 49 |
+
# tokenized_inputs["input_ids"][i][k]
|
| 50 |
+
# ).startswith("##"):
|
| 51 |
+
# offset_mapping[i][j][1] = offset_mapping[i][k][1]
|
| 52 |
+
# j = k
|
| 53 |
+
|
| 54 |
+
for i, offset_mapping in enumerate(offsets_mapping):
|
| 55 |
+
labels.append([])
|
| 56 |
+
|
| 57 |
+
spans = eval(examples["spans"][i])
|
| 58 |
+
Bs = eval(examples["Bs"][i])
|
| 59 |
+
Is = eval(examples["Is"][i])
|
| 60 |
+
|
| 61 |
+
example_spans.append(spans)
|
| 62 |
+
# cls_label = 2 ## DUMMY LABEL
|
| 63 |
+
cls_label = 3 ## DUMMY LABEL
|
| 64 |
+
for j, offsets in enumerate(offset_mapping):
|
| 65 |
+
if tokenized_inputs["input_ids"][i][j] in [
|
| 66 |
+
self.tokenizer.sep_token_id,
|
| 67 |
+
self.tokenizer.pad_token_id,
|
| 68 |
+
]:
|
| 69 |
+
tokenized_inputs["attention_mask"][i][j] = 0
|
| 70 |
+
|
| 71 |
+
if tokenized_inputs["input_ids"][i][j] == self.tokenizer.cls_token_id:
|
| 72 |
+
labels[-1].append(cls_label)
|
| 73 |
+
prediction_mask[i][j] = 1
|
| 74 |
+
|
| 75 |
+
elif offsets[0] == offsets[1] and offsets[0] == 0:
|
| 76 |
+
# labels[-1].append(2) ## DUMMY
|
| 77 |
+
labels[-1].append(cls_label) ## DUMMY
|
| 78 |
+
|
| 79 |
+
else:
|
| 80 |
+
# toxic_offsets = [x in spans for x in range(offsets[0], offsets[1])]
|
| 81 |
+
# ## If any part of the the token is in span, mark it as Toxic
|
| 82 |
+
# if (
|
| 83 |
+
# len(toxic_offsets) > 0
|
| 84 |
+
# and sum(toxic_offsets) / len(toxic_offsets) > 0.0
|
| 85 |
+
# ):
|
| 86 |
+
# labels[-1].append(1)
|
| 87 |
+
# else:
|
| 88 |
+
# labels[-1].append(0)
|
| 89 |
+
# prediction_mask[i][j] = 1
|
| 90 |
+
|
| 91 |
+
b_off = [x in Bs for x in range(offsets[0], offsets[1])]
|
| 92 |
+
b_off = sum(b_off)
|
| 93 |
+
i_off = [x in Is for x in range(offsets[0], offsets[1])]
|
| 94 |
+
i_off = sum(i_off)
|
| 95 |
+
# if len(b_off) == len(i_off) and len(i_off) == 0:
|
| 96 |
+
if b_off == 0 and i_off == 0:
|
| 97 |
+
labels[-1].append(0)
|
| 98 |
+
# elif len(b_off) >= len(i_off) == 1:
|
| 99 |
+
elif b_off >= i_off:
|
| 100 |
+
labels[-1].append(1)
|
| 101 |
+
# print(b_off)
|
| 102 |
+
# print(i_off)
|
| 103 |
+
# print(j)
|
| 104 |
+
else:
|
| 105 |
+
labels[-1].append(2)
|
| 106 |
+
|
| 107 |
+
tokenized_inputs["labels"] = labels
|
| 108 |
+
tokenized_inputs["prediction_mask"] = prediction_mask
|
| 109 |
+
return tokenized_inputs
|
| 110 |
+
|
| 111 |
+
def tokenize_for_test(self, examples):
|
| 112 |
+
tokenized_inputs = self.tokenizer(
|
| 113 |
+
examples["text"], **self.config.tokenizer_params
|
| 114 |
+
)
|
| 115 |
+
prediction_mask = np.zeros_like(np.array(tokenized_inputs["input_ids"]))
|
| 116 |
+
labels = np.zeros_like(np.array(tokenized_inputs["input_ids"]))
|
| 117 |
+
|
| 118 |
+
offsets_mapping = tokenized_inputs["offset_mapping"]
|
| 119 |
+
|
| 120 |
+
for i, offset_mapping in enumerate(offsets_mapping):
|
| 121 |
+
for j, offsets in enumerate(offset_mapping):
|
| 122 |
+
if tokenized_inputs["input_ids"][i][j] in [
|
| 123 |
+
self.tokenizer.sep_token_id,
|
| 124 |
+
self.tokenizer.pad_token_id,
|
| 125 |
+
]:
|
| 126 |
+
tokenized_inputs["attention_mask"][i][j] = 0
|
| 127 |
+
else:
|
| 128 |
+
prediction_mask[i][j] = 1
|
| 129 |
+
|
| 130 |
+
tokenized_inputs["prediction_mask"] = prediction_mask
|
| 131 |
+
tokenized_inputs["labels"] = labels
|
| 132 |
+
return tokenized_inputs
|
src/datasets/toxic_spans_crf_tokens.py
ADDED
|
@@ -0,0 +1,111 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from src.utils.mapper import configmapper
|
| 2 |
+
from transformers import AutoTokenizer
|
| 3 |
+
from datasets import load_dataset
|
| 4 |
+
import numpy as np
|
| 5 |
+
|
| 6 |
+
|
| 7 |
+
@configmapper.map("datasets", "toxic_spans_crf_tokens")
|
| 8 |
+
class ToxicSpansCRFTokenDataset:
|
| 9 |
+
def __init__(self, config):
|
| 10 |
+
self.config = config
|
| 11 |
+
self.tokenizer = AutoTokenizer.from_pretrained(
|
| 12 |
+
self.config.model_checkpoint_name
|
| 13 |
+
)
|
| 14 |
+
self.dataset = load_dataset("csv", data_files=dict(self.config.train_files))
|
| 15 |
+
self.test_dataset = load_dataset("csv", data_files=dict(self.config.eval_files))
|
| 16 |
+
|
| 17 |
+
self.tokenized_inputs = self.dataset.map(
|
| 18 |
+
self.tokenize_and_align_labels_for_train, batched=True
|
| 19 |
+
)
|
| 20 |
+
self.test_tokenized_inputs = self.test_dataset.map(
|
| 21 |
+
self.tokenize_for_test, batched=True
|
| 22 |
+
)
|
| 23 |
+
|
| 24 |
+
def tokenize_and_align_labels_for_train(self, examples):
|
| 25 |
+
tokenized_inputs = self.tokenizer(
|
| 26 |
+
examples["text"], **self.config.tokenizer_params
|
| 27 |
+
)
|
| 28 |
+
|
| 29 |
+
# tokenized_inputs["text"] = examples["text"]
|
| 30 |
+
example_spans = []
|
| 31 |
+
labels = []
|
| 32 |
+
prediction_mask = np.zeros_like(np.array(tokenized_inputs["input_ids"]))
|
| 33 |
+
offsets_mapping = tokenized_inputs["offset_mapping"]
|
| 34 |
+
|
| 35 |
+
## Wrong Code
|
| 36 |
+
# for i, offset_mapping in enumerate(offsets_mapping):
|
| 37 |
+
# j = 0
|
| 38 |
+
# while j < len(offset_mapping): # [tok1, tok2, tok3] [(0,5),(1,4),(5,7)]
|
| 39 |
+
# if tokenized_inputs["input_ids"][i][j] in [
|
| 40 |
+
# self.tokenizer.sep_token_id,
|
| 41 |
+
# self.tokenizer.pad_token_id,
|
| 42 |
+
# self.tokenizer.cls_token_id,
|
| 43 |
+
# ]:
|
| 44 |
+
# j = j + 1
|
| 45 |
+
# continue
|
| 46 |
+
# else:
|
| 47 |
+
# k = j + 1
|
| 48 |
+
# while self.tokenizer.convert_ids_to_tokens(
|
| 49 |
+
# tokenized_inputs["input_ids"][i][k]
|
| 50 |
+
# ).startswith("##"):
|
| 51 |
+
# offset_mapping[i][j][1] = offset_mapping[i][k][1]
|
| 52 |
+
# j = k
|
| 53 |
+
|
| 54 |
+
for i, offset_mapping in enumerate(offsets_mapping):
|
| 55 |
+
labels.append([])
|
| 56 |
+
|
| 57 |
+
spans = eval(examples["spans"][i])
|
| 58 |
+
example_spans.append(spans)
|
| 59 |
+
cls_label = 2 ## DUMMY LABEL
|
| 60 |
+
for j, offsets in enumerate(offset_mapping):
|
| 61 |
+
if tokenized_inputs["input_ids"][i][j] in [
|
| 62 |
+
self.tokenizer.sep_token_id,
|
| 63 |
+
self.tokenizer.pad_token_id,
|
| 64 |
+
]:
|
| 65 |
+
tokenized_inputs["attention_mask"][i][j] = 0
|
| 66 |
+
|
| 67 |
+
if tokenized_inputs["input_ids"][i][j] == self.tokenizer.cls_token_id:
|
| 68 |
+
labels[-1].append(cls_label)
|
| 69 |
+
prediction_mask[i][j] = 1
|
| 70 |
+
|
| 71 |
+
elif offsets[0] == offsets[1] and offsets[0] == 0:
|
| 72 |
+
labels[-1].append(2) ## DUMMY
|
| 73 |
+
|
| 74 |
+
else:
|
| 75 |
+
toxic_offsets = [x in spans for x in range(offsets[0], offsets[1])]
|
| 76 |
+
## If any part of the the token is in span, mark it as Toxic
|
| 77 |
+
if (
|
| 78 |
+
len(toxic_offsets) > 0
|
| 79 |
+
and sum(toxic_offsets) / len(toxic_offsets) > 0.0
|
| 80 |
+
):
|
| 81 |
+
labels[-1].append(1)
|
| 82 |
+
else:
|
| 83 |
+
labels[-1].append(0)
|
| 84 |
+
prediction_mask[i][j] = 1
|
| 85 |
+
|
| 86 |
+
tokenized_inputs["labels"] = labels
|
| 87 |
+
tokenized_inputs["prediction_mask"] = prediction_mask
|
| 88 |
+
return tokenized_inputs
|
| 89 |
+
|
| 90 |
+
def tokenize_for_test(self, examples):
|
| 91 |
+
tokenized_inputs = self.tokenizer(
|
| 92 |
+
examples["text"], **self.config.tokenizer_params
|
| 93 |
+
)
|
| 94 |
+
prediction_mask = np.zeros_like(np.array(tokenized_inputs["input_ids"]))
|
| 95 |
+
labels = np.zeros_like(np.array(tokenized_inputs["input_ids"]))
|
| 96 |
+
|
| 97 |
+
offsets_mapping = tokenized_inputs["offset_mapping"]
|
| 98 |
+
|
| 99 |
+
for i, offset_mapping in enumerate(offsets_mapping):
|
| 100 |
+
for j, offsets in enumerate(offset_mapping):
|
| 101 |
+
if tokenized_inputs["input_ids"][i][j] in [
|
| 102 |
+
self.tokenizer.sep_token_id,
|
| 103 |
+
self.tokenizer.pad_token_id,
|
| 104 |
+
]:
|
| 105 |
+
tokenized_inputs["attention_mask"][i][j] = 0
|
| 106 |
+
else:
|
| 107 |
+
prediction_mask[i][j] = 1
|
| 108 |
+
|
| 109 |
+
tokenized_inputs["prediction_mask"] = prediction_mask
|
| 110 |
+
tokenized_inputs["labels"] = labels
|
| 111 |
+
return tokenized_inputs
|
src/datasets/toxic_spans_multi_spans.py
ADDED
|
@@ -0,0 +1,237 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from src.utils.mapper import configmapper
|
| 2 |
+
from transformers import AutoTokenizer
|
| 3 |
+
import pandas as pd
|
| 4 |
+
from datasets import load_dataset, Dataset
|
| 5 |
+
from evaluation.fix_spans import _contiguous_ranges
|
| 6 |
+
|
| 7 |
+
|
| 8 |
+
@configmapper.map("datasets", "toxic_spans_multi_spans")
|
| 9 |
+
class ToxicSpansMultiSpansDataset:
|
| 10 |
+
def __init__(self, config):
|
| 11 |
+
self.config = config
|
| 12 |
+
self.tokenizer = AutoTokenizer.from_pretrained(
|
| 13 |
+
self.config.model_checkpoint_name
|
| 14 |
+
)
|
| 15 |
+
|
| 16 |
+
self.dataset = load_dataset("csv", data_files=dict(self.config.train_files))
|
| 17 |
+
self.test_dataset = load_dataset("csv", data_files=dict(self.config.eval_files))
|
| 18 |
+
|
| 19 |
+
temp_key_train = list(self.dataset.keys())[0]
|
| 20 |
+
self.intermediate_dataset = self.dataset.map(
|
| 21 |
+
self.create_train_features,
|
| 22 |
+
batched=True,
|
| 23 |
+
batch_size=1000000, ##Unusually Large Batch Size ## Needed For Correct ID mapping
|
| 24 |
+
remove_columns=self.dataset[temp_key_train].column_names,
|
| 25 |
+
)
|
| 26 |
+
|
| 27 |
+
temp_key_test = list(self.test_dataset.keys())[0]
|
| 28 |
+
self.intermediate_test_dataset = self.test_dataset.map(
|
| 29 |
+
self.create_test_features,
|
| 30 |
+
batched=True,
|
| 31 |
+
batch_size=1000000, ##Unusually Large Batch Size ## Needed For Correct ID mapping
|
| 32 |
+
remove_columns=self.test_dataset[temp_key_test].column_names,
|
| 33 |
+
)
|
| 34 |
+
|
| 35 |
+
self.tokenized_inputs = self.intermediate_dataset.map(
|
| 36 |
+
self.prepare_train_features,
|
| 37 |
+
batched=True,
|
| 38 |
+
remove_columns=self.intermediate_dataset[temp_key_train].column_names,
|
| 39 |
+
)
|
| 40 |
+
self.test_tokenized_inputs = self.intermediate_test_dataset.map(
|
| 41 |
+
self.prepare_test_features,
|
| 42 |
+
batched=True,
|
| 43 |
+
remove_columns=self.intermediate_test_dataset[temp_key_test].column_names,
|
| 44 |
+
)
|
| 45 |
+
|
| 46 |
+
def create_train_features(self, examples):
|
| 47 |
+
features = {
|
| 48 |
+
"context": [],
|
| 49 |
+
"id": [],
|
| 50 |
+
"question": [],
|
| 51 |
+
"title": [],
|
| 52 |
+
"start_positions": [],
|
| 53 |
+
"end_positions": [],
|
| 54 |
+
}
|
| 55 |
+
id = 0
|
| 56 |
+
# print(examples)
|
| 57 |
+
for row_number in range(len(examples["text"])):
|
| 58 |
+
context = examples["text"][row_number]
|
| 59 |
+
question = "offense"
|
| 60 |
+
title = context.split(" ")[0]
|
| 61 |
+
start_positions = []
|
| 62 |
+
end_positions = []
|
| 63 |
+
span = eval(examples["spans"][row_number])
|
| 64 |
+
contiguous_spans = _contiguous_ranges(span)
|
| 65 |
+
for lst in contiguous_spans:
|
| 66 |
+
lst = list(lst)
|
| 67 |
+
dict_to_write = {}
|
| 68 |
+
|
| 69 |
+
start_positions.append(lst[0])
|
| 70 |
+
end_positions.append(lst[1])
|
| 71 |
+
|
| 72 |
+
features["context"].append(context)
|
| 73 |
+
features["id"].append(str(id))
|
| 74 |
+
features["question"].append(question)
|
| 75 |
+
features["title"].append(title)
|
| 76 |
+
features["start_positions"].append(start_positions)
|
| 77 |
+
features["end_positions"].append(end_positions)
|
| 78 |
+
id += 1
|
| 79 |
+
|
| 80 |
+
return features
|
| 81 |
+
|
| 82 |
+
def create_test_features(self, examples):
|
| 83 |
+
features = {"context": [], "id": [], "question": [], "title": []}
|
| 84 |
+
id = 0
|
| 85 |
+
for row_number in range(len(examples["text"])):
|
| 86 |
+
context = examples["text"][row_number]
|
| 87 |
+
question = "offense"
|
| 88 |
+
title = context.split(" ")[0]
|
| 89 |
+
features["context"].append(context)
|
| 90 |
+
features["id"].append(str(id))
|
| 91 |
+
features["question"].append(question)
|
| 92 |
+
features["title"].append(title)
|
| 93 |
+
id += 1
|
| 94 |
+
return features
|
| 95 |
+
|
| 96 |
+
def prepare_train_features(self, examples):
|
| 97 |
+
"""Generate tokenized features from examples.
|
| 98 |
+
|
| 99 |
+
Args:
|
| 100 |
+
examples (dict): The examples to be tokenized.
|
| 101 |
+
|
| 102 |
+
Returns:
|
| 103 |
+
transformers.tokenization_utils_base.BatchEncoding:
|
| 104 |
+
The tokenized features/examples after processing.
|
| 105 |
+
"""
|
| 106 |
+
# Tokenize our examples with truncation and padding, but keep the
|
| 107 |
+
# overflows using a stride. This results in one example possible
|
| 108 |
+
# giving several features when a context is long, each of those
|
| 109 |
+
# features having a context that overlaps a bit the context
|
| 110 |
+
# of the previous feature.
|
| 111 |
+
pad_on_right = self.tokenizer.padding_side == "right"
|
| 112 |
+
print("### Batch Tokenizing Examples ###")
|
| 113 |
+
tokenized_examples = self.tokenizer(
|
| 114 |
+
examples["question" if pad_on_right else "context"],
|
| 115 |
+
examples["context" if pad_on_right else "question"],
|
| 116 |
+
**dict(self.config.tokenizer_params),
|
| 117 |
+
)
|
| 118 |
+
|
| 119 |
+
# Since one example might give us several features if it has
|
| 120 |
+
# a long context, we need a map from a feature to
|
| 121 |
+
# its corresponding example. This key gives us just that.
|
| 122 |
+
sample_mapping = tokenized_examples.pop("overflow_to_sample_mapping")
|
| 123 |
+
# The offset mappings will give us a map from token to
|
| 124 |
+
# character position in the original context. This will
|
| 125 |
+
# help us compute the start_positions and end_positions.
|
| 126 |
+
offset_mapping = tokenized_examples.pop("offset_mapping")
|
| 127 |
+
|
| 128 |
+
# Let's label those examples!
|
| 129 |
+
tokenized_examples["start_positions"] = []
|
| 130 |
+
tokenized_examples["end_positions"] = []
|
| 131 |
+
|
| 132 |
+
for i, offsets in enumerate(offset_mapping):
|
| 133 |
+
# We will label impossible answers with the index of the CLS token.
|
| 134 |
+
input_ids = tokenized_examples["input_ids"][i]
|
| 135 |
+
|
| 136 |
+
# Grab the sequence corresponding to that example
|
| 137 |
+
# (to know what is the context and what is the question).
|
| 138 |
+
sequence_ids = tokenized_examples.sequence_ids(i)
|
| 139 |
+
|
| 140 |
+
# One example can give several spans, this is the index of
|
| 141 |
+
# the example containing this span of text.
|
| 142 |
+
sample_index = sample_mapping[i]
|
| 143 |
+
start_positions = examples["start_positions"][sample_index]
|
| 144 |
+
end_positions = examples["end_positions"][sample_index]
|
| 145 |
+
|
| 146 |
+
start_positions_token_wise = [0 for x in range(len(input_ids))]
|
| 147 |
+
end_positions_token_wise = [0 for x in range(len(input_ids))]
|
| 148 |
+
# If no answers are given, set the cls_index as answer.
|
| 149 |
+
if len(start_positions) != 0:
|
| 150 |
+
for position in range(len(start_positions)):
|
| 151 |
+
start_char = start_positions[position]
|
| 152 |
+
end_char = end_positions[position] + 1
|
| 153 |
+
|
| 154 |
+
# Start token index of the current span in the text.
|
| 155 |
+
token_start_index = 0
|
| 156 |
+
while sequence_ids[token_start_index] != (1 if pad_on_right else 0):
|
| 157 |
+
token_start_index += 1
|
| 158 |
+
|
| 159 |
+
# End token index of the current span in the text.
|
| 160 |
+
token_end_index = len(input_ids) - 1
|
| 161 |
+
while sequence_ids[token_end_index] != (1 if pad_on_right else 0):
|
| 162 |
+
token_end_index -= 1
|
| 163 |
+
|
| 164 |
+
# Detect if the answer is out of the span (in which case we continue).
|
| 165 |
+
if not (
|
| 166 |
+
offsets[token_start_index][0] <= start_char
|
| 167 |
+
and offsets[token_end_index][1] >= end_char
|
| 168 |
+
):
|
| 169 |
+
continue
|
| 170 |
+
else:
|
| 171 |
+
# Otherwise move the token_start_index and token_end_index to the two ends of the answer.
|
| 172 |
+
# Note: we could go after the last offset if the answer is the last word (edge case).
|
| 173 |
+
while (
|
| 174 |
+
token_start_index < len(offsets)
|
| 175 |
+
and offsets[token_start_index][0] <= start_char
|
| 176 |
+
):
|
| 177 |
+
token_start_index += 1
|
| 178 |
+
start_positions_token_wise[token_start_index - 1] = 1
|
| 179 |
+
while offsets[token_end_index][1] >= end_char:
|
| 180 |
+
token_end_index -= 1
|
| 181 |
+
end_positions_token_wise[token_end_index + 1] = 1
|
| 182 |
+
tokenized_examples["start_positions"].append(start_positions_token_wise)
|
| 183 |
+
tokenized_examples["end_positions"].append(start_positions_token_wise)
|
| 184 |
+
return tokenized_examples
|
| 185 |
+
|
| 186 |
+
def prepare_test_features(self, examples):
|
| 187 |
+
|
| 188 |
+
"""Generate tokenized validation features from examples.
|
| 189 |
+
|
| 190 |
+
Args:
|
| 191 |
+
examples (dict): The validation examples to be tokenized.
|
| 192 |
+
|
| 193 |
+
Returns:
|
| 194 |
+
transformers.tokenization_utils_base.BatchEncoding:
|
| 195 |
+
The tokenized features/examples for validation set after processing.
|
| 196 |
+
"""
|
| 197 |
+
|
| 198 |
+
# Tokenize our examples with truncation and maybe
|
| 199 |
+
# padding, but keep the overflows using a stride.
|
| 200 |
+
# This results in one example possible giving several features
|
| 201 |
+
# when a context is long, each of those features having a
|
| 202 |
+
# context that overlaps a bit the context of the previous feature.
|
| 203 |
+
print("### Tokenizing Validation Examples")
|
| 204 |
+
pad_on_right = self.tokenizer.padding_side == "right"
|
| 205 |
+
tokenized_examples = self.tokenizer(
|
| 206 |
+
examples["question" if pad_on_right else "context"],
|
| 207 |
+
examples["context" if pad_on_right else "question"],
|
| 208 |
+
**dict(self.config.tokenizer_params),
|
| 209 |
+
)
|
| 210 |
+
|
| 211 |
+
# Since one example might give us several features if it has a long context,
|
| 212 |
+
# we need a map from a feature to its corresponding example. This key gives us just that.
|
| 213 |
+
sample_mapping = tokenized_examples.pop("overflow_to_sample_mapping")
|
| 214 |
+
|
| 215 |
+
# We keep the example_id that gave us this feature and we will store the offset mappings.
|
| 216 |
+
tokenized_examples["example_id"] = []
|
| 217 |
+
|
| 218 |
+
for i in range(len(tokenized_examples["input_ids"])):
|
| 219 |
+
# Grab the sequence corresponding to that example
|
| 220 |
+
# (to know what is the context and what is the question).
|
| 221 |
+
sequence_ids = tokenized_examples.sequence_ids(i)
|
| 222 |
+
context_index = 1 if pad_on_right else 0
|
| 223 |
+
|
| 224 |
+
# One example can give several spans,
|
| 225 |
+
# this is the index of the example containing this span of text.
|
| 226 |
+
sample_index = sample_mapping[i]
|
| 227 |
+
tokenized_examples["example_id"].append(str(examples["id"][sample_index]))
|
| 228 |
+
|
| 229 |
+
# Set to None the offset_mapping that are not part
|
| 230 |
+
# of the context so it's easy to determine if a token
|
| 231 |
+
# position is part of the context or not.
|
| 232 |
+
tokenized_examples["offset_mapping"][i] = [
|
| 233 |
+
(o if sequence_ids[k] == context_index else None)
|
| 234 |
+
for k, o in enumerate(tokenized_examples["offset_mapping"][i])
|
| 235 |
+
]
|
| 236 |
+
|
| 237 |
+
return tokenized_examples
|
src/datasets/toxic_spans_spans.py
ADDED
|
@@ -0,0 +1,238 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from src.utils.mapper import configmapper
|
| 2 |
+
from transformers import AutoTokenizer
|
| 3 |
+
import pandas as pd
|
| 4 |
+
from datasets import load_dataset, Dataset
|
| 5 |
+
from evaluation.fix_spans import _contiguous_ranges
|
| 6 |
+
|
| 7 |
+
|
| 8 |
+
@configmapper.map("datasets", "toxic_spans_spans")
|
| 9 |
+
class ToxicSpansSpansDataset:
|
| 10 |
+
def __init__(self, config):
|
| 11 |
+
# print("### ToxicSpansSpansDataset ###"); exit()
|
| 12 |
+
self.config = config
|
| 13 |
+
self.tokenizer = AutoTokenizer.from_pretrained(
|
| 14 |
+
self.config.model_checkpoint_name
|
| 15 |
+
)
|
| 16 |
+
|
| 17 |
+
self.dataset = load_dataset("csv", data_files=dict(self.config.train_files))
|
| 18 |
+
self.test_dataset = load_dataset("csv", data_files=dict(self.config.eval_files))
|
| 19 |
+
|
| 20 |
+
temp_key_train = list(self.dataset.keys())[0]
|
| 21 |
+
self.intermediate_dataset = self.dataset.map(
|
| 22 |
+
self.create_train_features,
|
| 23 |
+
batched=True,
|
| 24 |
+
batch_size=1000000, ##Unusually Large Batch Size ## Needed For Correct ID mapping
|
| 25 |
+
remove_columns=self.dataset[temp_key_train].column_names,
|
| 26 |
+
)
|
| 27 |
+
|
| 28 |
+
temp_key_test = list(self.test_dataset.keys())[0]
|
| 29 |
+
self.intermediate_test_dataset = self.test_dataset.map(
|
| 30 |
+
self.create_test_features,
|
| 31 |
+
batched=True,
|
| 32 |
+
batch_size=1000000, ##Unusually Large Batch Size ## Needed For Correct ID mapping
|
| 33 |
+
remove_columns=self.test_dataset[temp_key_test].column_names,
|
| 34 |
+
)
|
| 35 |
+
|
| 36 |
+
self.tokenized_inputs = self.intermediate_dataset.map(
|
| 37 |
+
self.prepare_train_features,
|
| 38 |
+
batched=True,
|
| 39 |
+
remove_columns=self.intermediate_dataset[temp_key_train].column_names,
|
| 40 |
+
)
|
| 41 |
+
self.test_tokenized_inputs = self.intermediate_test_dataset.map(
|
| 42 |
+
self.prepare_test_features,
|
| 43 |
+
batched=True,
|
| 44 |
+
remove_columns=self.intermediate_test_dataset[temp_key_test].column_names,
|
| 45 |
+
)
|
| 46 |
+
|
| 47 |
+
def create_train_features(self, examples):
|
| 48 |
+
features = {"context": [], "id": [], "question": [], "title": []}
|
| 49 |
+
id = 0
|
| 50 |
+
# print(examples)
|
| 51 |
+
for row_number in range(len(examples["text"])):
|
| 52 |
+
context = examples["text"][row_number]
|
| 53 |
+
# question = "offense"
|
| 54 |
+
question = "ভুল"
|
| 55 |
+
title = context.split(" ")[0]
|
| 56 |
+
span = eval(examples["spans"][row_number])
|
| 57 |
+
contiguous_spans = _contiguous_ranges(span)
|
| 58 |
+
for lst in contiguous_spans:
|
| 59 |
+
lst = list(lst)
|
| 60 |
+
dict_to_write = {}
|
| 61 |
+
|
| 62 |
+
dict_to_write["answer_start"] = [lst[0]]
|
| 63 |
+
dict_to_write["text"] = [context[lst[0] : lst[-1] + 1]]
|
| 64 |
+
# print(dict_to_write)
|
| 65 |
+
if "answers" in features.keys():
|
| 66 |
+
features["answers"].append(dict_to_write)
|
| 67 |
+
else:
|
| 68 |
+
features["answers"] = [
|
| 69 |
+
dict_to_write,
|
| 70 |
+
]
|
| 71 |
+
features["context"].append(context)
|
| 72 |
+
features["id"].append(str(id))
|
| 73 |
+
features["question"].append(question)
|
| 74 |
+
features["title"].append(title)
|
| 75 |
+
id += 1
|
| 76 |
+
|
| 77 |
+
return features
|
| 78 |
+
|
| 79 |
+
def create_test_features(self, examples):
|
| 80 |
+
features = {"context": [], "id": [], "question": [], "title": []}
|
| 81 |
+
id = 0
|
| 82 |
+
for row_number in range(len(examples["text"])):
|
| 83 |
+
context = examples["text"][row_number]
|
| 84 |
+
# question = "offense"
|
| 85 |
+
question = "ভুল"
|
| 86 |
+
title = context.split(" ")[0]
|
| 87 |
+
features["context"].append(context)
|
| 88 |
+
features["id"].append(str(id))
|
| 89 |
+
features["question"].append(question)
|
| 90 |
+
features["title"].append(title)
|
| 91 |
+
id += 1
|
| 92 |
+
return features
|
| 93 |
+
|
| 94 |
+
def prepare_train_features(self, examples):
|
| 95 |
+
"""Generate tokenized features from examples.
|
| 96 |
+
|
| 97 |
+
Args:
|
| 98 |
+
examples (dict): The examples to be tokenized.
|
| 99 |
+
|
| 100 |
+
Returns:
|
| 101 |
+
transformers.tokenization_utils_base.BatchEncoding:
|
| 102 |
+
The tokenized features/examples after processing.
|
| 103 |
+
"""
|
| 104 |
+
# Tokenize our examples with truncation and padding, but keep the
|
| 105 |
+
# overflows using a stride. This results in one example possible
|
| 106 |
+
# giving several features when a context is long, each of those
|
| 107 |
+
# features having a context that overlaps a bit the context
|
| 108 |
+
# of the previous feature.
|
| 109 |
+
pad_on_right = self.tokenizer.padding_side == "right"
|
| 110 |
+
print("### Batch Tokenizing Examples ###")
|
| 111 |
+
tokenized_examples = self.tokenizer(
|
| 112 |
+
examples["question" if pad_on_right else "context"],
|
| 113 |
+
examples["context" if pad_on_right else "question"],
|
| 114 |
+
**dict(self.config.tokenizer_params),
|
| 115 |
+
)
|
| 116 |
+
|
| 117 |
+
# Since one example might give us several features if it has
|
| 118 |
+
# a long context, we need a map from a feature to
|
| 119 |
+
# its corresponding example. This key gives us just that.
|
| 120 |
+
sample_mapping = tokenized_examples.pop("overflow_to_sample_mapping")
|
| 121 |
+
# The offset mappings will give us a map from token to
|
| 122 |
+
# character position in the original context. This will
|
| 123 |
+
# help us compute the start_positions and end_positions.
|
| 124 |
+
offset_mapping = tokenized_examples.pop("offset_mapping")
|
| 125 |
+
|
| 126 |
+
# Let's label those examples!
|
| 127 |
+
tokenized_examples["start_positions"] = []
|
| 128 |
+
tokenized_examples["end_positions"] = []
|
| 129 |
+
|
| 130 |
+
for i, offsets in enumerate(offset_mapping):
|
| 131 |
+
# We will label impossible answers with the index of the CLS token.
|
| 132 |
+
input_ids = tokenized_examples["input_ids"][i]
|
| 133 |
+
cls_index = input_ids.index(self.tokenizer.cls_token_id)
|
| 134 |
+
|
| 135 |
+
# Grab the sequence corresponding to that example
|
| 136 |
+
# (to know what is the context and what is the question).
|
| 137 |
+
sequence_ids = tokenized_examples.sequence_ids(i)
|
| 138 |
+
|
| 139 |
+
# One example can give several spans, this is the index of
|
| 140 |
+
# the example containing this span of text.
|
| 141 |
+
sample_index = sample_mapping[i]
|
| 142 |
+
answers = examples["answers"][sample_index]
|
| 143 |
+
# If no answers are given, set the cls_index as answer.
|
| 144 |
+
if len(answers["answer_start"]) == 0:
|
| 145 |
+
tokenized_examples["start_positions"].append(cls_index)
|
| 146 |
+
tokenized_examples["end_positions"].append(cls_index)
|
| 147 |
+
else:
|
| 148 |
+
# Start/end character index of the answer in the text.
|
| 149 |
+
start_char = answers["answer_start"][0]
|
| 150 |
+
end_char = start_char + len(answers["text"][0])
|
| 151 |
+
|
| 152 |
+
# Start token index of the current span in the text.
|
| 153 |
+
token_start_index = 0
|
| 154 |
+
while sequence_ids[token_start_index] != (1 if pad_on_right else 0):
|
| 155 |
+
token_start_index += 1
|
| 156 |
+
|
| 157 |
+
# End token index of the current span in the text.
|
| 158 |
+
token_end_index = len(input_ids) - 1
|
| 159 |
+
while sequence_ids[token_end_index] != (1 if pad_on_right else 0):
|
| 160 |
+
token_end_index -= 1
|
| 161 |
+
|
| 162 |
+
# Detect if the answer is out of the span
|
| 163 |
+
# (in which case this feature is labeled with the CLS index).
|
| 164 |
+
if not (
|
| 165 |
+
offsets[token_start_index][0] <= start_char
|
| 166 |
+
and offsets[token_end_index][1] >= end_char
|
| 167 |
+
):
|
| 168 |
+
tokenized_examples["start_positions"].append(cls_index)
|
| 169 |
+
tokenized_examples["end_positions"].append(cls_index)
|
| 170 |
+
else:
|
| 171 |
+
# Otherwise move the token_start_index and
|
| 172 |
+
# stoken_end_index to the two ends of the answer.
|
| 173 |
+
# Note: we could go after the last offset
|
| 174 |
+
# if the answer is the last word (edge case).
|
| 175 |
+
while (
|
| 176 |
+
token_start_index < len(offsets)
|
| 177 |
+
and offsets[token_start_index][0] <= start_char
|
| 178 |
+
):
|
| 179 |
+
token_start_index += 1
|
| 180 |
+
tokenized_examples["start_positions"].append(token_start_index - 1)
|
| 181 |
+
while offsets[token_end_index][1] >= end_char:
|
| 182 |
+
token_end_index -= 1
|
| 183 |
+
tokenized_examples["end_positions"].append(token_end_index + 1)
|
| 184 |
+
|
| 185 |
+
return tokenized_examples
|
| 186 |
+
|
| 187 |
+
def prepare_test_features(self, examples):
|
| 188 |
+
|
| 189 |
+
"""Generate tokenized validation features from examples.
|
| 190 |
+
|
| 191 |
+
Args:
|
| 192 |
+
examples (dict): The validation examples to be tokenized.
|
| 193 |
+
|
| 194 |
+
Returns:
|
| 195 |
+
transformers.tokenization_utils_base.BatchEncoding:
|
| 196 |
+
The tokenized features/examples for validation set after processing.
|
| 197 |
+
"""
|
| 198 |
+
|
| 199 |
+
# Tokenize our examples with truncation and maybe
|
| 200 |
+
# padding, but keep the overflows using a stride.
|
| 201 |
+
# This results in one example possible giving several features
|
| 202 |
+
# when a context is long, each of those features having a
|
| 203 |
+
# context that overlaps a bit the context of the previous feature.
|
| 204 |
+
print("### Tokenizing Validation Examples")
|
| 205 |
+
pad_on_right = self.tokenizer.padding_side == "right"
|
| 206 |
+
tokenized_examples = self.tokenizer(
|
| 207 |
+
examples["question" if pad_on_right else "context"],
|
| 208 |
+
examples["context" if pad_on_right else "question"],
|
| 209 |
+
**dict(self.config.tokenizer_params),
|
| 210 |
+
)
|
| 211 |
+
|
| 212 |
+
# Since one example might give us several features if it has a long context,
|
| 213 |
+
# we need a map from a feature to its corresponding example. This key gives us just that.
|
| 214 |
+
sample_mapping = tokenized_examples.pop("overflow_to_sample_mapping")
|
| 215 |
+
|
| 216 |
+
# We keep the example_id that gave us this feature and we will store the offset mappings.
|
| 217 |
+
tokenized_examples["example_id"] = []
|
| 218 |
+
|
| 219 |
+
for i in range(len(tokenized_examples["input_ids"])):
|
| 220 |
+
# Grab the sequence corresponding to that example
|
| 221 |
+
# (to know what is the context and what is the question).
|
| 222 |
+
sequence_ids = tokenized_examples.sequence_ids(i)
|
| 223 |
+
context_index = 1 if pad_on_right else 0
|
| 224 |
+
|
| 225 |
+
# One example can give several spans,
|
| 226 |
+
# this is the index of the example containing this span of text.
|
| 227 |
+
sample_index = sample_mapping[i]
|
| 228 |
+
tokenized_examples["example_id"].append(str(examples["id"][sample_index]))
|
| 229 |
+
|
| 230 |
+
# Set to None the offset_mapping that are not part
|
| 231 |
+
# of the context so it's easy to determine if a token
|
| 232 |
+
# position is part of the context or not.
|
| 233 |
+
tokenized_examples["offset_mapping"][i] = [
|
| 234 |
+
(o if sequence_ids[k] == context_index else None)
|
| 235 |
+
for k, o in enumerate(tokenized_examples["offset_mapping"][i])
|
| 236 |
+
]
|
| 237 |
+
|
| 238 |
+
return tokenized_examples
|
src/datasets/toxic_spans_tokens.py
ADDED
|
@@ -0,0 +1,81 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from src.utils.mapper import configmapper
|
| 2 |
+
from transformers import AutoTokenizer
|
| 3 |
+
from datasets import load_dataset
|
| 4 |
+
|
| 5 |
+
# import pdb
|
| 6 |
+
|
| 7 |
+
@configmapper.map("datasets", "toxic_spans_tokens")
|
| 8 |
+
class ToxicSpansTokenDataset:
|
| 9 |
+
def __init__(self, config):
|
| 10 |
+
# print("### ToxicSpansTokenDataset ###"); exit()
|
| 11 |
+
self.config = config
|
| 12 |
+
self.tokenizer = AutoTokenizer.from_pretrained(
|
| 13 |
+
self.config.model_checkpoint_name
|
| 14 |
+
)
|
| 15 |
+
# if self.config.model_checkpoint_name == "sberbank-ai/mGPT":
|
| 16 |
+
# self.tokenizer.add_special_tokens({'pad_token': '[PAD]'})
|
| 17 |
+
self.dataset = load_dataset("csv", data_files=dict(self.config.train_files))
|
| 18 |
+
self.test_dataset = load_dataset("csv", data_files=dict(self.config.eval_files))
|
| 19 |
+
|
| 20 |
+
self.tokenized_inputs = self.dataset.map(
|
| 21 |
+
self.tokenize_and_align_labels_for_train, batched=True
|
| 22 |
+
)
|
| 23 |
+
self.test_tokenized_inputs = self.test_dataset.map(
|
| 24 |
+
self.tokenize_for_test, batched=True
|
| 25 |
+
)
|
| 26 |
+
|
| 27 |
+
def tokenize_and_align_labels_for_train(self, examples):
|
| 28 |
+
|
| 29 |
+
tokenized_inputs = self.tokenizer(
|
| 30 |
+
examples["text"], **self.config.tokenizer_params
|
| 31 |
+
)
|
| 32 |
+
|
| 33 |
+
# tokenized_inputs["text"] = examples["text"]
|
| 34 |
+
example_spans = []
|
| 35 |
+
labels = []
|
| 36 |
+
|
| 37 |
+
offsets_mapping = tokenized_inputs["offset_mapping"]
|
| 38 |
+
# pdb.set_trace()
|
| 39 |
+
for i, offset_mapping in enumerate(offsets_mapping):
|
| 40 |
+
labels.append([])
|
| 41 |
+
|
| 42 |
+
spans = eval(examples["spans"][i])
|
| 43 |
+
example_spans.append(spans)
|
| 44 |
+
if self.config.label_cls:
|
| 45 |
+
cls_label = (
|
| 46 |
+
1
|
| 47 |
+
if (
|
| 48 |
+
len(examples["text"][i]) > 0
|
| 49 |
+
and len(spans) / len(examples["text"][i])
|
| 50 |
+
> self.config.cls_threshold
|
| 51 |
+
)
|
| 52 |
+
else 0
|
| 53 |
+
) ## Make class label based on threshold
|
| 54 |
+
else:
|
| 55 |
+
cls_label = -100
|
| 56 |
+
for j, offsets in enumerate(offset_mapping):
|
| 57 |
+
if tokenized_inputs["input_ids"][i][j] == self.tokenizer.cls_token_id:
|
| 58 |
+
labels[-1].append(cls_label)
|
| 59 |
+
elif offsets[0] == offsets[1] and offsets[0] == 0: # All zero
|
| 60 |
+
labels[-1].append(-100) ## SPECIAL TOKEN
|
| 61 |
+
else:
|
| 62 |
+
toxic_offsets = [x in spans for x in range(offsets[0], offsets[1])]
|
| 63 |
+
## If any part of the the token is in span, mark it as Toxic
|
| 64 |
+
if (
|
| 65 |
+
len(toxic_offsets) > 0
|
| 66 |
+
and sum(toxic_offsets) / len(toxic_offsets)
|
| 67 |
+
> self.config.token_threshold
|
| 68 |
+
):
|
| 69 |
+
labels[-1].append(1)
|
| 70 |
+
else:
|
| 71 |
+
labels[-1].append(0)
|
| 72 |
+
|
| 73 |
+
tokenized_inputs["labels"] = labels
|
| 74 |
+
# print("tokenized_inputs", tokenized_inputs); exit()
|
| 75 |
+
return tokenized_inputs
|
| 76 |
+
|
| 77 |
+
def tokenize_for_test(self, examples):
|
| 78 |
+
tokenized_inputs = self.tokenizer(
|
| 79 |
+
examples["text"], **self.config.tokenizer_params
|
| 80 |
+
)
|
| 81 |
+
return tokenized_inputs
|
src/datasets/toxic_spans_tokens_3cls.py
ADDED
|
@@ -0,0 +1,102 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from src.utils.mapper import configmapper
|
| 2 |
+
from transformers import AutoTokenizer
|
| 3 |
+
from datasets import load_dataset
|
| 4 |
+
|
| 5 |
+
import pdb
|
| 6 |
+
|
| 7 |
+
@configmapper.map("datasets", "toxic_spans_tokens_3cls")
|
| 8 |
+
class ToxicSpansToken3CLSDataset:
|
| 9 |
+
def __init__(self, config):
|
| 10 |
+
# print("### ToxicSpansTokenDataset ###"); exit()
|
| 11 |
+
self.config = config
|
| 12 |
+
self.tokenizer = AutoTokenizer.from_pretrained(
|
| 13 |
+
self.config.model_checkpoint_name
|
| 14 |
+
)
|
| 15 |
+
# if self.config.model_checkpoint_name == "sberbank-ai/mGPT":
|
| 16 |
+
# self.tokenizer.add_special_tokens({'pad_token': '[PAD]'})
|
| 17 |
+
self.dataset = load_dataset("csv", data_files=dict(self.config.train_files))
|
| 18 |
+
self.test_dataset = load_dataset("csv", data_files=dict(self.config.eval_files))
|
| 19 |
+
|
| 20 |
+
self.tokenized_inputs = self.dataset.map(
|
| 21 |
+
self.tokenize_and_align_labels_for_train, batched=True
|
| 22 |
+
)
|
| 23 |
+
self.test_tokenized_inputs = self.test_dataset.map(
|
| 24 |
+
self.tokenize_for_test, batched=True
|
| 25 |
+
)
|
| 26 |
+
|
| 27 |
+
def tokenize_and_align_labels_for_train(self, examples):
|
| 28 |
+
|
| 29 |
+
tokenized_inputs = self.tokenizer(
|
| 30 |
+
examples["text"], **self.config.tokenizer_params
|
| 31 |
+
)
|
| 32 |
+
|
| 33 |
+
# tokenized_inputs["text"] = examples["text"]
|
| 34 |
+
example_spans = []
|
| 35 |
+
labels = []
|
| 36 |
+
|
| 37 |
+
offsets_mapping = tokenized_inputs["offset_mapping"]
|
| 38 |
+
# pdb.set_trace()
|
| 39 |
+
for i, offset_mapping in enumerate(offsets_mapping):
|
| 40 |
+
labels.append([])
|
| 41 |
+
|
| 42 |
+
spans = eval(examples["spans"][i])
|
| 43 |
+
Bs = eval(examples["Bs"][i])
|
| 44 |
+
Is = eval(examples["Is"][i])
|
| 45 |
+
example_spans.append(spans)
|
| 46 |
+
if self.config.label_cls:
|
| 47 |
+
cls_label = (
|
| 48 |
+
1
|
| 49 |
+
if (
|
| 50 |
+
len(examples["text"][i]) > 0
|
| 51 |
+
and len(spans) / len(examples["text"][i])
|
| 52 |
+
> self.config.cls_threshold
|
| 53 |
+
)
|
| 54 |
+
else 0
|
| 55 |
+
) ## Make class label based on threshold
|
| 56 |
+
else:
|
| 57 |
+
cls_label = -100
|
| 58 |
+
for j, offsets in enumerate(offset_mapping):
|
| 59 |
+
if tokenized_inputs["input_ids"][i][j] == self.tokenizer.cls_token_id:
|
| 60 |
+
labels[-1].append(cls_label)
|
| 61 |
+
elif offsets[0] == offsets[1] and offsets[0] == 0: # All zero
|
| 62 |
+
labels[-1].append(-100) ## SPECIAL TOKEN
|
| 63 |
+
else:
|
| 64 |
+
# toxic_offsets = [x in spans for x in range(offsets[0], offsets[1])]
|
| 65 |
+
## If any part of the the token is in span, mark it as Toxic
|
| 66 |
+
# if (
|
| 67 |
+
# len(toxic_offsets) > 0
|
| 68 |
+
# and sum(toxic_offsets) / len(toxic_offsets)
|
| 69 |
+
# > self.config.token_threshold
|
| 70 |
+
# ):
|
| 71 |
+
# labels[-1].append(1)
|
| 72 |
+
# else:
|
| 73 |
+
# labels[-1].append(0)
|
| 74 |
+
b_off = [x in Bs for x in range(offsets[0], offsets[1])]
|
| 75 |
+
b_off = sum(b_off)
|
| 76 |
+
i_off = [x in Is for x in range(offsets[0], offsets[1])]
|
| 77 |
+
i_off = sum(i_off)
|
| 78 |
+
# if len(b_off) == len(i_off) and len(i_off) == 0:
|
| 79 |
+
if b_off == 0 and i_off == 0:
|
| 80 |
+
labels[-1].append(0)
|
| 81 |
+
# elif len(b_off) >= len(i_off) == 1:
|
| 82 |
+
elif b_off >= i_off:
|
| 83 |
+
labels[-1].append(1)
|
| 84 |
+
# print(b_off)
|
| 85 |
+
# print(i_off)
|
| 86 |
+
# print(j)
|
| 87 |
+
else:
|
| 88 |
+
labels[-1].append(2)
|
| 89 |
+
|
| 90 |
+
# pdb.set_trace()
|
| 91 |
+
|
| 92 |
+
|
| 93 |
+
|
| 94 |
+
tokenized_inputs["labels"] = labels
|
| 95 |
+
# print("tokenized_inputs", tokenized_inputs); exit()
|
| 96 |
+
return tokenized_inputs
|
| 97 |
+
|
| 98 |
+
def tokenize_for_test(self, examples):
|
| 99 |
+
tokenized_inputs = self.tokenizer(
|
| 100 |
+
examples["text"], **self.config.tokenizer_params
|
| 101 |
+
)
|
| 102 |
+
return tokenized_inputs
|
src/datasets/toxic_spans_tokens_spans.py
ADDED
|
@@ -0,0 +1,269 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from src.utils.mapper import configmapper
|
| 2 |
+
from transformers import AutoTokenizer
|
| 3 |
+
import pandas as pd
|
| 4 |
+
from datasets import load_dataset, Dataset
|
| 5 |
+
from evaluation.fix_spans import _contiguous_ranges
|
| 6 |
+
|
| 7 |
+
|
| 8 |
+
@configmapper.map("datasets", "toxic_spans_tokens_spans")
|
| 9 |
+
class ToxicSpansTokensSpansDataset:
|
| 10 |
+
def __init__(self, config):
|
| 11 |
+
self.config = config
|
| 12 |
+
self.tokenizer = AutoTokenizer.from_pretrained(
|
| 13 |
+
self.config.model_checkpoint_name
|
| 14 |
+
)
|
| 15 |
+
|
| 16 |
+
self.dataset = load_dataset("csv", data_files=dict(self.config.train_files))
|
| 17 |
+
self.test_dataset = load_dataset("csv", data_files=dict(self.config.eval_files))
|
| 18 |
+
|
| 19 |
+
temp_key_train = list(self.dataset.keys())[0]
|
| 20 |
+
self.intermediate_dataset = self.dataset.map(
|
| 21 |
+
self.create_train_features,
|
| 22 |
+
batched=True,
|
| 23 |
+
batch_size=1000000, ##Unusually Large Batch Size ## Needed For Correct ID mapping
|
| 24 |
+
remove_columns=self.dataset[temp_key_train].column_names,
|
| 25 |
+
)
|
| 26 |
+
|
| 27 |
+
temp_key_test = list(self.test_dataset.keys())[0]
|
| 28 |
+
self.intermediate_test_dataset = self.test_dataset.map(
|
| 29 |
+
self.create_test_features,
|
| 30 |
+
batched=True,
|
| 31 |
+
batch_size=1000000, ##Unusually Large Batch Size ## Needed For Correct ID mapping
|
| 32 |
+
remove_columns=self.test_dataset[temp_key_test].column_names,
|
| 33 |
+
)
|
| 34 |
+
|
| 35 |
+
self.tokenized_inputs = self.intermediate_dataset.map(
|
| 36 |
+
self.prepare_train_features,
|
| 37 |
+
batched=True,
|
| 38 |
+
remove_columns=self.intermediate_dataset[temp_key_train].column_names,
|
| 39 |
+
)
|
| 40 |
+
self.test_tokenized_inputs = self.intermediate_test_dataset.map(
|
| 41 |
+
self.prepare_test_features,
|
| 42 |
+
batched=True,
|
| 43 |
+
remove_columns=self.intermediate_test_dataset[temp_key_test].column_names,
|
| 44 |
+
)
|
| 45 |
+
|
| 46 |
+
def create_train_features(self, examples):
|
| 47 |
+
features = {"context": [], "id": [], "question": [], "title": [], "spans": []}
|
| 48 |
+
id = 0
|
| 49 |
+
# print(examples)
|
| 50 |
+
for row_number in range(len(examples["text"])):
|
| 51 |
+
context = examples["text"][row_number]
|
| 52 |
+
question = "offense"
|
| 53 |
+
title = context.split(" ")[0]
|
| 54 |
+
span = eval(examples["spans"][row_number])
|
| 55 |
+
contiguous_spans = _contiguous_ranges(span)
|
| 56 |
+
for lst in contiguous_spans:
|
| 57 |
+
lst = list(lst)
|
| 58 |
+
dict_to_write = {}
|
| 59 |
+
|
| 60 |
+
dict_to_write["answer_start"] = [lst[0]]
|
| 61 |
+
dict_to_write["text"] = [context[lst[0] : lst[-1] + 1]]
|
| 62 |
+
# print(dict_to_write)
|
| 63 |
+
if "answers" in features.keys():
|
| 64 |
+
features["answers"].append(dict_to_write)
|
| 65 |
+
else:
|
| 66 |
+
features["answers"] = [
|
| 67 |
+
dict_to_write,
|
| 68 |
+
]
|
| 69 |
+
features["context"].append(context)
|
| 70 |
+
features["id"].append(str(id))
|
| 71 |
+
features["question"].append(question)
|
| 72 |
+
features["title"].append(title)
|
| 73 |
+
features["spans"].append(span)
|
| 74 |
+
id += 1
|
| 75 |
+
|
| 76 |
+
return features
|
| 77 |
+
|
| 78 |
+
def create_test_features(self, examples):
|
| 79 |
+
features = {"context": [], "id": [], "question": [], "title": []}
|
| 80 |
+
id = 0
|
| 81 |
+
for row_number in range(len(examples["text"])):
|
| 82 |
+
context = examples["text"][row_number]
|
| 83 |
+
question = "offense"
|
| 84 |
+
title = context.split(" ")[0]
|
| 85 |
+
features["context"].append(context)
|
| 86 |
+
features["id"].append(str(id))
|
| 87 |
+
features["question"].append(question)
|
| 88 |
+
features["title"].append(title)
|
| 89 |
+
id += 1
|
| 90 |
+
return features
|
| 91 |
+
|
| 92 |
+
def prepare_train_features(self, examples):
|
| 93 |
+
"""Generate tokenized features from examples.
|
| 94 |
+
|
| 95 |
+
Args:
|
| 96 |
+
examples (dict): The examples to be tokenized.
|
| 97 |
+
|
| 98 |
+
Returns:
|
| 99 |
+
transformers.tokenization_utils_base.BatchEncoding:
|
| 100 |
+
The tokenized features/examples after processing.
|
| 101 |
+
"""
|
| 102 |
+
# Tokenize our examples with truncation and padding, but keep the
|
| 103 |
+
# overflows using a stride. This results in one example possible
|
| 104 |
+
# giving several features when a context is long, each of those
|
| 105 |
+
# features having a context that overlaps a bit the context
|
| 106 |
+
# of the previous feature.
|
| 107 |
+
pad_on_right = self.tokenizer.padding_side == "right"
|
| 108 |
+
print("### Batch Tokenizing Examples ###")
|
| 109 |
+
tokenized_examples = self.tokenizer(
|
| 110 |
+
examples["question" if pad_on_right else "context"],
|
| 111 |
+
examples["context" if pad_on_right else "question"],
|
| 112 |
+
**dict(self.config.tokenizer_params),
|
| 113 |
+
)
|
| 114 |
+
|
| 115 |
+
# Since one example might give us several features if it has
|
| 116 |
+
# a long context, we need a map from a feature to
|
| 117 |
+
# its corresponding example. This key gives us just that.
|
| 118 |
+
sample_mapping = tokenized_examples.pop("overflow_to_sample_mapping")
|
| 119 |
+
# The offset mappings will give us a map from token to
|
| 120 |
+
# character position in the original context. This will
|
| 121 |
+
# help us compute the start_positions and end_positions.
|
| 122 |
+
offset_mapping = tokenized_examples.pop("offset_mapping")
|
| 123 |
+
|
| 124 |
+
# Let's label those examples!
|
| 125 |
+
token_labels = []
|
| 126 |
+
tokenized_examples["start_positions"] = []
|
| 127 |
+
tokenized_examples["end_positions"] = []
|
| 128 |
+
|
| 129 |
+
for i, offsets in enumerate(offset_mapping):
|
| 130 |
+
# We will label impossible answers with the index of the CLS token.
|
| 131 |
+
|
| 132 |
+
token_labels.append([])
|
| 133 |
+
input_ids = tokenized_examples["input_ids"][i]
|
| 134 |
+
spans = examples["spans"][i]
|
| 135 |
+
if self.config.label_cls:
|
| 136 |
+
cls_label = (
|
| 137 |
+
1
|
| 138 |
+
if (
|
| 139 |
+
len(examples["context"][i]) > 0
|
| 140 |
+
and len(spans) / len(examples["context"][i])
|
| 141 |
+
> self.config.cls_threshold
|
| 142 |
+
)
|
| 143 |
+
else 0
|
| 144 |
+
) ## Make class label based on threshold
|
| 145 |
+
else:
|
| 146 |
+
cls_label = -100
|
| 147 |
+
for j, offset in enumerate(offsets):
|
| 148 |
+
if tokenized_examples["input_ids"][i][j] == self.tokenizer.cls_token_id:
|
| 149 |
+
token_labels[-1].append(cls_label)
|
| 150 |
+
elif offset[0] == offset[1] and offset[0] == 0:
|
| 151 |
+
token_labels[-1].append(-100) ## SPECIAL TOKEN
|
| 152 |
+
else:
|
| 153 |
+
toxic_offsets = [x in spans for x in range(offset[0], offset[1])]
|
| 154 |
+
## If any part of the the token is in span, mark it as Toxic
|
| 155 |
+
if (
|
| 156 |
+
len(toxic_offsets) > 0
|
| 157 |
+
and sum(toxic_offsets) / len(toxic_offsets)
|
| 158 |
+
> self.config.token_threshold
|
| 159 |
+
):
|
| 160 |
+
token_labels[-1].append(1)
|
| 161 |
+
else:
|
| 162 |
+
token_labels[-1].append(0)
|
| 163 |
+
|
| 164 |
+
cls_index = input_ids.index(self.tokenizer.cls_token_id)
|
| 165 |
+
|
| 166 |
+
# Grab the sequence corresponding to that example
|
| 167 |
+
# (to know what is the context and what is the question).
|
| 168 |
+
sequence_ids = tokenized_examples.sequence_ids(i)
|
| 169 |
+
|
| 170 |
+
# One example can give several spans, this is the index of
|
| 171 |
+
# the example containing this span of text.
|
| 172 |
+
sample_index = sample_mapping[i]
|
| 173 |
+
answers = examples["answers"][sample_index]
|
| 174 |
+
# If no answers are given, set the cls_index as answer.
|
| 175 |
+
if len(answers["answer_start"]) == 0:
|
| 176 |
+
tokenized_examples["start_positions"].append(cls_index)
|
| 177 |
+
tokenized_examples["end_positions"].append(cls_index)
|
| 178 |
+
else:
|
| 179 |
+
# Start/end character index of the answer in the text.
|
| 180 |
+
start_char = answers["answer_start"][0]
|
| 181 |
+
end_char = start_char + len(answers["text"][0])
|
| 182 |
+
|
| 183 |
+
# Start token index of the current span in the text.
|
| 184 |
+
token_start_index = 0
|
| 185 |
+
while sequence_ids[token_start_index] != (1 if pad_on_right else 0):
|
| 186 |
+
token_start_index += 1
|
| 187 |
+
|
| 188 |
+
# End token index of the current span in the text.
|
| 189 |
+
token_end_index = len(input_ids) - 1
|
| 190 |
+
while sequence_ids[token_end_index] != (1 if pad_on_right else 0):
|
| 191 |
+
token_end_index -= 1
|
| 192 |
+
|
| 193 |
+
# Detect if the answer is out of the span
|
| 194 |
+
# (in which case this feature is labeled with the CLS index).
|
| 195 |
+
if not (
|
| 196 |
+
offsets[token_start_index][0] <= start_char
|
| 197 |
+
and offsets[token_end_index][1] >= end_char
|
| 198 |
+
):
|
| 199 |
+
tokenized_examples["start_positions"].append(cls_index)
|
| 200 |
+
tokenized_examples["end_positions"].append(cls_index)
|
| 201 |
+
else:
|
| 202 |
+
# Otherwise move the token_start_index and
|
| 203 |
+
# stoken_end_index to the two ends of the answer.
|
| 204 |
+
# Note: we could go after the last offset
|
| 205 |
+
# if the answer is the last word (edge case).
|
| 206 |
+
while (
|
| 207 |
+
token_start_index < len(offsets)
|
| 208 |
+
and offsets[token_start_index][0] <= start_char
|
| 209 |
+
):
|
| 210 |
+
token_start_index += 1
|
| 211 |
+
tokenized_examples["start_positions"].append(token_start_index - 1)
|
| 212 |
+
while offsets[token_end_index][1] >= end_char:
|
| 213 |
+
token_end_index -= 1
|
| 214 |
+
tokenized_examples["end_positions"].append(token_end_index + 1)
|
| 215 |
+
tokenized_examples["labels"] = token_labels
|
| 216 |
+
return tokenized_examples
|
| 217 |
+
|
| 218 |
+
def prepare_test_features(self, examples):
|
| 219 |
+
|
| 220 |
+
"""Generate tokenized validation features from examples.
|
| 221 |
+
|
| 222 |
+
Args:
|
| 223 |
+
examples (dict): The validation examples to be tokenized.
|
| 224 |
+
|
| 225 |
+
Returns:
|
| 226 |
+
transformers.tokenization_utils_base.BatchEncoding:
|
| 227 |
+
The tokenized features/examples for validation set after processing.
|
| 228 |
+
"""
|
| 229 |
+
|
| 230 |
+
# Tokenize our examples with truncation and maybe
|
| 231 |
+
# padding, but keep the overflows using a stride.
|
| 232 |
+
# This results in one example possible giving several features
|
| 233 |
+
# when a context is long, each of those features having a
|
| 234 |
+
# context that overlaps a bit the context of the previous feature.
|
| 235 |
+
print("### Tokenizing Validation Examples")
|
| 236 |
+
pad_on_right = self.tokenizer.padding_side == "right"
|
| 237 |
+
tokenized_examples = self.tokenizer(
|
| 238 |
+
examples["question" if pad_on_right else "context"],
|
| 239 |
+
examples["context" if pad_on_right else "question"],
|
| 240 |
+
**dict(self.config.tokenizer_params),
|
| 241 |
+
)
|
| 242 |
+
|
| 243 |
+
# Since one example might give us several features if it has a long context,
|
| 244 |
+
# we need a map from a feature to its corresponding example. This key gives us just that.
|
| 245 |
+
sample_mapping = tokenized_examples.pop("overflow_to_sample_mapping")
|
| 246 |
+
|
| 247 |
+
# We keep the example_id that gave us this feature and we will store the offset mappings.
|
| 248 |
+
tokenized_examples["example_id"] = []
|
| 249 |
+
|
| 250 |
+
for i in range(len(tokenized_examples["input_ids"])):
|
| 251 |
+
# Grab the sequence corresponding to that example
|
| 252 |
+
# (to know what is the context and what is the question).
|
| 253 |
+
sequence_ids = tokenized_examples.sequence_ids(i)
|
| 254 |
+
context_index = 1 if pad_on_right else 0
|
| 255 |
+
|
| 256 |
+
# One example can give several spans,
|
| 257 |
+
# this is the index of the example containing this span of text.
|
| 258 |
+
sample_index = sample_mapping[i]
|
| 259 |
+
tokenized_examples["example_id"].append(str(examples["id"][sample_index]))
|
| 260 |
+
|
| 261 |
+
# Set to None the offset_mapping that are not part
|
| 262 |
+
# of the context so it's easy to determine if a token
|
| 263 |
+
# position is part of the context or not.
|
| 264 |
+
tokenized_examples["offset_mapping"][i] = [
|
| 265 |
+
(o if sequence_ids[k] == context_index else None)
|
| 266 |
+
for k, o in enumerate(tokenized_examples["offset_mapping"][i])
|
| 267 |
+
]
|
| 268 |
+
|
| 269 |
+
return tokenized_examples
|
src/models/__init__.py
ADDED
|
@@ -0,0 +1,7 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from src.models.auto_models import *
|
| 2 |
+
from src.models.bert_token_spans import *
|
| 3 |
+
from src.models.roberta_token_spans import *
|
| 4 |
+
from src.models.bert_multi_spans import *
|
| 5 |
+
from src.models.roberta_multi_spans import *
|
| 6 |
+
from src.models.bert_crf_token import *
|
| 7 |
+
from src.models.roberta_crf_token import *
|
src/models/__pycache__/__init__.cpython-38.pyc
ADDED
|
Binary file (443 Bytes). View file
|
|
|
src/models/__pycache__/auto_models.cpython-38.pyc
ADDED
|
Binary file (436 Bytes). View file
|
|
|
src/models/__pycache__/bert_crf_token.cpython-38.pyc
ADDED
|
Binary file (1.68 kB). View file
|
|
|
src/models/__pycache__/bert_multi_spans.cpython-38.pyc
ADDED
|
Binary file (1.72 kB). View file
|
|
|
src/models/__pycache__/bert_token_spans.cpython-38.pyc
ADDED
|
Binary file (2.34 kB). View file
|
|
|
src/models/__pycache__/roberta_crf_token.cpython-38.pyc
ADDED
|
Binary file (1.69 kB). View file
|
|
|
src/models/__pycache__/roberta_multi_spans.cpython-38.pyc
ADDED
|
Binary file (1.79 kB). View file
|
|
|
src/models/__pycache__/roberta_token_spans.cpython-38.pyc
ADDED
|
Binary file (2.42 kB). View file
|
|
|
src/models/auto_models.py
ADDED
|
@@ -0,0 +1,6 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from transformers import AutoModelForTokenClassification, AutoModelForQuestionAnswering
|
| 2 |
+
from src.utils.mapper import configmapper
|
| 3 |
+
|
| 4 |
+
configmapper.map("models", "autotoken")(AutoModelForTokenClassification)
|
| 5 |
+
configmapper.map("models", "autotoken_3cls")(AutoModelForTokenClassification)
|
| 6 |
+
configmapper.map("models", "autospans")(AutoModelForQuestionAnswering)
|
src/models/bert_crf_token.py
ADDED
|
@@ -0,0 +1,72 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import torch
|
| 2 |
+
# from transformers import BertForTokenClassification
|
| 3 |
+
from transformers import ElectraForTokenClassification
|
| 4 |
+
from torchcrf import CRF
|
| 5 |
+
from src.utils.mapper import configmapper
|
| 6 |
+
# import pdb
|
| 7 |
+
|
| 8 |
+
|
| 9 |
+
@configmapper.map("models", "bert_crf_token")
|
| 10 |
+
# class BertLSTMCRF(BertForTokenClassification):
|
| 11 |
+
class BertLSTMCRF(ElectraForTokenClassification):
|
| 12 |
+
def __init__(self, config, lstm_hidden_size, lstm_layers):
|
| 13 |
+
super().__init__(config)
|
| 14 |
+
# ipdb.set_trace()
|
| 15 |
+
self.lstm = torch.nn.LSTM(
|
| 16 |
+
input_size=config.hidden_size,
|
| 17 |
+
hidden_size=lstm_hidden_size,
|
| 18 |
+
num_layers=lstm_layers,
|
| 19 |
+
dropout=0.2,
|
| 20 |
+
batch_first=True,
|
| 21 |
+
bidirectional=True,
|
| 22 |
+
)
|
| 23 |
+
self.crf = CRF(config.num_labels, batch_first=True)
|
| 24 |
+
|
| 25 |
+
del self.classifier
|
| 26 |
+
self.classifier = torch.nn.Linear(2 * lstm_hidden_size, config.num_labels)
|
| 27 |
+
|
| 28 |
+
def forward(
|
| 29 |
+
self,
|
| 30 |
+
input_ids,
|
| 31 |
+
attention_mask=None,
|
| 32 |
+
token_type_ids=None,
|
| 33 |
+
labels=None,
|
| 34 |
+
prediction_mask=None,
|
| 35 |
+
):
|
| 36 |
+
# pdb.set_trace()
|
| 37 |
+
|
| 38 |
+
# outputs = self.bert(
|
| 39 |
+
outputs = self.electra(
|
| 40 |
+
input_ids,
|
| 41 |
+
attention_mask,
|
| 42 |
+
token_type_ids,
|
| 43 |
+
output_hidden_states=True,
|
| 44 |
+
return_dict=False,
|
| 45 |
+
)
|
| 46 |
+
# seq_output, all_hidden_states, all_self_attntions, all_cross_attentions
|
| 47 |
+
|
| 48 |
+
sequence_output = outputs[0] # outputs[1] is pooled output which is none.
|
| 49 |
+
|
| 50 |
+
sequence_output = self.dropout(sequence_output)
|
| 51 |
+
|
| 52 |
+
lstm_out, *_ = self.lstm(sequence_output)
|
| 53 |
+
sequence_output = self.dropout(lstm_out)
|
| 54 |
+
|
| 55 |
+
logits = self.classifier(sequence_output)
|
| 56 |
+
|
| 57 |
+
## CRF
|
| 58 |
+
mask = prediction_mask
|
| 59 |
+
mask = mask[:, : logits.size(1)].contiguous()
|
| 60 |
+
|
| 61 |
+
# print(logits)
|
| 62 |
+
|
| 63 |
+
if labels is not None:
|
| 64 |
+
labels = labels[:, : logits.size(1)].contiguous()
|
| 65 |
+
loss = -self.crf(logits, labels, mask=mask.bool(), reduction="token_mean")
|
| 66 |
+
|
| 67 |
+
tags = self.crf.decode(logits, mask.bool())
|
| 68 |
+
# print(tags)
|
| 69 |
+
if labels is not None:
|
| 70 |
+
return (loss, logits, tags)
|
| 71 |
+
else:
|
| 72 |
+
return (logits, tags)
|
src/models/bert_multi_spans.py
ADDED
|
@@ -0,0 +1,84 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import torch.nn as nn
|
| 2 |
+
from torch.nn import BCEWithLogitsLoss
|
| 3 |
+
# from transformers import BertModel, BertPreTrainedModel
|
| 4 |
+
from transformers import ElectraPreTrainedModel, ElectraModel
|
| 5 |
+
from src.utils.mapper import configmapper
|
| 6 |
+
|
| 7 |
+
|
| 8 |
+
@configmapper.map("models", "bert_multi_spans")
|
| 9 |
+
# class BertForMultiSpans(BertPreTrainedModel):
|
| 10 |
+
class BertForMultiSpans(ElectraPreTrainedModel):
|
| 11 |
+
def __init__(self, config):
|
| 12 |
+
super(BertForMultiSpans, self).__init__(config)
|
| 13 |
+
# self.bert = BertModel(config)
|
| 14 |
+
self.bert = ElectraModel(config)
|
| 15 |
+
self.num_labels = config.num_labels
|
| 16 |
+
|
| 17 |
+
# TODO check with Google if it's normal there is no dropout on the token classifier of SQuAD in the TF version
|
| 18 |
+
# self.dropout = nn.Dropout(config.hidden_dropout_prob)
|
| 19 |
+
self.qa_outputs = nn.Linear(config.hidden_size, config.num_labels)
|
| 20 |
+
self.init_weights()
|
| 21 |
+
|
| 22 |
+
def forward(
|
| 23 |
+
self,
|
| 24 |
+
input_ids=None,
|
| 25 |
+
attention_mask=None,
|
| 26 |
+
token_type_ids=None,
|
| 27 |
+
position_ids=None,
|
| 28 |
+
head_mask=None,
|
| 29 |
+
inputs_embeds=None,
|
| 30 |
+
start_positions=None,
|
| 31 |
+
end_positions=None,
|
| 32 |
+
output_attentions=None,
|
| 33 |
+
output_hidden_states=None,
|
| 34 |
+
):
|
| 35 |
+
outputs = self.bert(
|
| 36 |
+
input_ids,
|
| 37 |
+
attention_mask=attention_mask,
|
| 38 |
+
token_type_ids=token_type_ids,
|
| 39 |
+
position_ids=position_ids,
|
| 40 |
+
head_mask=head_mask,
|
| 41 |
+
inputs_embeds=inputs_embeds,
|
| 42 |
+
output_attentions=output_attentions,
|
| 43 |
+
output_hidden_states=output_hidden_states,
|
| 44 |
+
return_dict=None,
|
| 45 |
+
)
|
| 46 |
+
|
| 47 |
+
sequence_output = outputs[0]
|
| 48 |
+
|
| 49 |
+
logits = self.qa_outputs(sequence_output)
|
| 50 |
+
start_logits, end_logits = logits.split(1, dim=-1)
|
| 51 |
+
start_logits = start_logits.squeeze(-1)
|
| 52 |
+
end_logits = end_logits.squeeze(-1) # batch_size
|
| 53 |
+
# print(start_logits.shape, end_logits.shape, start_positions.shape, end_positions.shape)
|
| 54 |
+
|
| 55 |
+
total_loss = None
|
| 56 |
+
if (
|
| 57 |
+
start_positions is not None and end_positions is not None
|
| 58 |
+
): # [batch_size/seq_length]
|
| 59 |
+
# # If we are on multi-GPU, split add a dimension
|
| 60 |
+
# if len(start_positions.size()) > 1:
|
| 61 |
+
# start_positions = start_positions.squeeze(-1)
|
| 62 |
+
# if len(end_positions.size()) > 1:
|
| 63 |
+
# end_positions = end_positions.squeeze(-1)
|
| 64 |
+
# sometimes the start/end positions are outside our model inputs, we ignore these terms
|
| 65 |
+
# ignored_index = start_logits.size(1)
|
| 66 |
+
# start_positions.clamp_(0, ignored_index)
|
| 67 |
+
# end_positions.clamp_(0, ignored_index)
|
| 68 |
+
|
| 69 |
+
# start_positions = start_logits.view()
|
| 70 |
+
|
| 71 |
+
loss_fct = BCEWithLogitsLoss()
|
| 72 |
+
|
| 73 |
+
start_loss = loss = loss_fct(
|
| 74 |
+
start_logits,
|
| 75 |
+
start_positions.float(),
|
| 76 |
+
)
|
| 77 |
+
end_loss = loss = loss_fct(
|
| 78 |
+
end_logits,
|
| 79 |
+
end_positions.float(),
|
| 80 |
+
)
|
| 81 |
+
total_loss = (start_loss + end_loss) / 2
|
| 82 |
+
|
| 83 |
+
output = (start_logits, end_logits) + outputs[2:]
|
| 84 |
+
return ((total_loss,) + output) if total_loss is not None else output
|
src/models/bert_token_spans.py
ADDED
|
@@ -0,0 +1,100 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import torch.nn as nn
|
| 2 |
+
import torch
|
| 3 |
+
from torch.nn import CrossEntropyLoss
|
| 4 |
+
# from transformers import BertPreTrainedModel, BertModel
|
| 5 |
+
from transformers import ElectraPreTrainedModel, ElectraModel
|
| 6 |
+
from src.utils.mapper import configmapper
|
| 7 |
+
|
| 8 |
+
|
| 9 |
+
@configmapper.map("models", "bert_token_spans")
|
| 10 |
+
# class BertModelForTokenAndSpans(BertPreTrainedModel):
|
| 11 |
+
class BertModelForTokenAndSpans(ElectraPreTrainedModel):
|
| 12 |
+
def __init__(self, config, num_token_labels=2, num_qa_labels=2):
|
| 13 |
+
super(BertModelForTokenAndSpans, self).__init__(config)
|
| 14 |
+
# self.bert = BertModel(config)
|
| 15 |
+
self.bert = ElectraModel(config)
|
| 16 |
+
self.num_token_labels = num_token_labels
|
| 17 |
+
self.num_qa_labels = num_qa_labels
|
| 18 |
+
# print("Number of Token Labels: ", num_token_labels); exit()
|
| 19 |
+
|
| 20 |
+
self.dropout = nn.Dropout(config.hidden_dropout_prob)
|
| 21 |
+
self.classifier = nn.Linear(config.hidden_size, num_token_labels)
|
| 22 |
+
self.qa_outputs = nn.Linear(config.hidden_size, num_qa_labels)
|
| 23 |
+
self.init_weights()
|
| 24 |
+
|
| 25 |
+
def forward(
|
| 26 |
+
self,
|
| 27 |
+
input_ids=None,
|
| 28 |
+
attention_mask=None,
|
| 29 |
+
token_type_ids=None,
|
| 30 |
+
position_ids=None,
|
| 31 |
+
head_mask=None,
|
| 32 |
+
inputs_embeds=None,
|
| 33 |
+
start_positions=None,
|
| 34 |
+
end_positions=None,
|
| 35 |
+
labels=None, # Token Wise Labels
|
| 36 |
+
output_attentions=None,
|
| 37 |
+
output_hidden_states=None,
|
| 38 |
+
):
|
| 39 |
+
|
| 40 |
+
outputs = self.bert(
|
| 41 |
+
input_ids,
|
| 42 |
+
attention_mask=attention_mask,
|
| 43 |
+
token_type_ids=token_type_ids,
|
| 44 |
+
position_ids=position_ids,
|
| 45 |
+
head_mask=head_mask,
|
| 46 |
+
inputs_embeds=inputs_embeds,
|
| 47 |
+
output_attentions=output_attentions,
|
| 48 |
+
output_hidden_states=output_hidden_states,
|
| 49 |
+
return_dict=None,
|
| 50 |
+
)
|
| 51 |
+
|
| 52 |
+
sequence_output = outputs[0]
|
| 53 |
+
|
| 54 |
+
qa_logits = self.qa_outputs(sequence_output)
|
| 55 |
+
start_logits, end_logits = qa_logits.split(1, dim=-1)
|
| 56 |
+
start_logits = start_logits.squeeze(-1)
|
| 57 |
+
end_logits = end_logits.squeeze(-1)
|
| 58 |
+
|
| 59 |
+
sequence_output = self.dropout(sequence_output)
|
| 60 |
+
token_logits = self.classifier(sequence_output)
|
| 61 |
+
|
| 62 |
+
total_loss = None
|
| 63 |
+
if (
|
| 64 |
+
start_positions is not None
|
| 65 |
+
and end_positions is not None
|
| 66 |
+
and labels is not None
|
| 67 |
+
):
|
| 68 |
+
# If we are on multi-GPU, split add a dimension
|
| 69 |
+
if len(start_positions.size()) > 1:
|
| 70 |
+
start_positions = start_positions.squeeze(-1)
|
| 71 |
+
if len(end_positions.size()) > 1:
|
| 72 |
+
end_positions = end_positions.squeeze(-1)
|
| 73 |
+
|
| 74 |
+
ignored_index = start_logits.size(1)
|
| 75 |
+
start_positions.clamp_(0, ignored_index)
|
| 76 |
+
end_positions.clamp_(0, ignored_index)
|
| 77 |
+
|
| 78 |
+
loss_fct = CrossEntropyLoss(ignore_index=ignored_index)
|
| 79 |
+
start_loss = loss_fct(start_logits, start_positions)
|
| 80 |
+
end_loss = loss_fct(end_logits, end_positions)
|
| 81 |
+
|
| 82 |
+
loss_fct = CrossEntropyLoss()
|
| 83 |
+
if attention_mask is not None:
|
| 84 |
+
active_loss = attention_mask.view(-1) == 1
|
| 85 |
+
active_logits = token_logits.view(-1, self.num_token_labels)
|
| 86 |
+
active_labels = torch.where(
|
| 87 |
+
active_loss,
|
| 88 |
+
labels.view(-1),
|
| 89 |
+
torch.tensor(loss_fct.ignore_index).type_as(labels),
|
| 90 |
+
)
|
| 91 |
+
token_loss = loss_fct(active_logits, active_labels)
|
| 92 |
+
else:
|
| 93 |
+
token_loss = loss_fct(
|
| 94 |
+
token_logits.view(-1, self.num_token_labels), labels.view(-1)
|
| 95 |
+
)
|
| 96 |
+
|
| 97 |
+
total_loss = (start_loss + end_loss) / 2 + token_loss
|
| 98 |
+
|
| 99 |
+
output = (start_logits, end_logits, token_logits) + outputs[2:]
|
| 100 |
+
return ((total_loss,) + output) if total_loss is not None else output
|
src/models/roberta_crf_token.py
ADDED
|
@@ -0,0 +1,66 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import torch
|
| 2 |
+
from transformers import RobertaForTokenClassification
|
| 3 |
+
from torchcrf import CRF
|
| 4 |
+
from src.utils.mapper import configmapper
|
| 5 |
+
|
| 6 |
+
|
| 7 |
+
@configmapper.map("models", "roberta_crf_token")
|
| 8 |
+
class RobertaLSTMCRF(RobertaForTokenClassification):
|
| 9 |
+
def __init__(self, config, lstm_hidden_size, lstm_layers):
|
| 10 |
+
super().__init__(config)
|
| 11 |
+
self.lstm = torch.nn.LSTM(
|
| 12 |
+
input_size=config.hidden_size,
|
| 13 |
+
hidden_size=lstm_hidden_size,
|
| 14 |
+
num_layers=lstm_layers,
|
| 15 |
+
dropout=0.2,
|
| 16 |
+
batch_first=True,
|
| 17 |
+
bidirectional=True,
|
| 18 |
+
)
|
| 19 |
+
self.crf = CRF(config.num_labels, batch_first=True)
|
| 20 |
+
|
| 21 |
+
del self.classifier
|
| 22 |
+
self.classifier = torch.nn.Linear(2 * lstm_hidden_size, config.num_labels)
|
| 23 |
+
|
| 24 |
+
def forward(
|
| 25 |
+
self,
|
| 26 |
+
input_ids,
|
| 27 |
+
attention_mask=None,
|
| 28 |
+
token_type_ids=None,
|
| 29 |
+
labels=None,
|
| 30 |
+
prediction_mask=None,
|
| 31 |
+
):
|
| 32 |
+
|
| 33 |
+
outputs = self.roberta(
|
| 34 |
+
input_ids,
|
| 35 |
+
attention_mask,
|
| 36 |
+
token_type_ids,
|
| 37 |
+
output_hidden_states=True,
|
| 38 |
+
return_dict=False,
|
| 39 |
+
)
|
| 40 |
+
# seq_output, all_hidden_states, all_self_attntions, all_cross_attentions
|
| 41 |
+
|
| 42 |
+
sequence_output = outputs[0] # outputs[1] is pooled output which is none.
|
| 43 |
+
|
| 44 |
+
sequence_output = self.dropout(sequence_output)
|
| 45 |
+
|
| 46 |
+
lstm_out, *_ = self.lstm(sequence_output)
|
| 47 |
+
sequence_output = self.dropout(lstm_out)
|
| 48 |
+
|
| 49 |
+
logits = self.classifier(sequence_output)
|
| 50 |
+
|
| 51 |
+
## CRF
|
| 52 |
+
mask = prediction_mask
|
| 53 |
+
mask = mask[:, : logits.size(1)].contiguous()
|
| 54 |
+
|
| 55 |
+
# print(logits)
|
| 56 |
+
|
| 57 |
+
if labels is not None:
|
| 58 |
+
labels = labels[:, : logits.size(1)].contiguous()
|
| 59 |
+
loss = -self.crf(logits, labels, mask=mask.bool(), reduction="token_mean")
|
| 60 |
+
|
| 61 |
+
tags = self.crf.decode(logits, mask.bool())
|
| 62 |
+
# print(tags)
|
| 63 |
+
if labels is not None:
|
| 64 |
+
return (loss, logits, tags)
|
| 65 |
+
else:
|
| 66 |
+
return (logits, tags)
|
src/models/roberta_multi_spans.py
ADDED
|
@@ -0,0 +1,82 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import torch.nn as nn
|
| 2 |
+
from torch.nn import BCEWithLogitsLoss
|
| 3 |
+
from transformers import RobertaModel
|
| 4 |
+
from transformers.models.roberta.modeling_roberta import RobertaPreTrainedModel
|
| 5 |
+
from src.utils.mapper import configmapper
|
| 6 |
+
|
| 7 |
+
|
| 8 |
+
@configmapper.map("models", "roberta_multi_spans")
|
| 9 |
+
class RobertaForMultiSpans(RobertaPreTrainedModel):
|
| 10 |
+
def __init__(self, config):
|
| 11 |
+
super(RobertaForMultiSpans, self).__init__(config)
|
| 12 |
+
self.roberta = RobertaModel(config)
|
| 13 |
+
self.num_labels = config.num_labels
|
| 14 |
+
|
| 15 |
+
# TODO check with Google if it's normal there is no dropout on the token classifier of SQuAD in the TF version
|
| 16 |
+
# self.dropout = nn.Dropout(config.hidden_dropout_prob)
|
| 17 |
+
self.qa_outputs = nn.Linear(config.hidden_size, config.num_labels)
|
| 18 |
+
self.init_weights()
|
| 19 |
+
|
| 20 |
+
def forward(
|
| 21 |
+
self,
|
| 22 |
+
input_ids=None,
|
| 23 |
+
attention_mask=None,
|
| 24 |
+
token_type_ids=None,
|
| 25 |
+
position_ids=None,
|
| 26 |
+
head_mask=None,
|
| 27 |
+
inputs_embeds=None,
|
| 28 |
+
start_positions=None,
|
| 29 |
+
end_positions=None,
|
| 30 |
+
output_attentions=None,
|
| 31 |
+
output_hidden_states=None,
|
| 32 |
+
):
|
| 33 |
+
outputs = self.roberta(
|
| 34 |
+
input_ids,
|
| 35 |
+
attention_mask=attention_mask,
|
| 36 |
+
token_type_ids=token_type_ids,
|
| 37 |
+
position_ids=position_ids,
|
| 38 |
+
head_mask=head_mask,
|
| 39 |
+
inputs_embeds=inputs_embeds,
|
| 40 |
+
output_attentions=output_attentions,
|
| 41 |
+
output_hidden_states=output_hidden_states,
|
| 42 |
+
return_dict=None,
|
| 43 |
+
)
|
| 44 |
+
|
| 45 |
+
sequence_output = outputs[0]
|
| 46 |
+
|
| 47 |
+
logits = self.qa_outputs(sequence_output)
|
| 48 |
+
start_logits, end_logits = logits.split(1, dim=-1)
|
| 49 |
+
start_logits = start_logits.squeeze(-1)
|
| 50 |
+
end_logits = end_logits.squeeze(-1) # batch_size
|
| 51 |
+
# print(start_logits.shape, end_logits.shape, start_positions.shape, end_positions.shape)
|
| 52 |
+
|
| 53 |
+
total_loss = None
|
| 54 |
+
if (
|
| 55 |
+
start_positions is not None and end_positions is not None
|
| 56 |
+
): # [batch_size/seq_length]
|
| 57 |
+
# # If we are on multi-GPU, split add a dimension
|
| 58 |
+
# if len(start_positions.size()) > 1:
|
| 59 |
+
# start_positions = start_positions.squeeze(-1)
|
| 60 |
+
# if len(end_positions.size()) > 1:
|
| 61 |
+
# end_positions = end_positions.squeeze(-1)
|
| 62 |
+
# sometimes the start/end positions are outside our model inputs, we ignore these terms
|
| 63 |
+
# ignored_index = start_logits.size(1)
|
| 64 |
+
# start_positions.clamp_(0, ignored_index)
|
| 65 |
+
# end_positions.clamp_(0, ignored_index)
|
| 66 |
+
|
| 67 |
+
# start_positions = start_logits.view()
|
| 68 |
+
|
| 69 |
+
loss_fct = BCEWithLogitsLoss()
|
| 70 |
+
|
| 71 |
+
start_loss = loss = loss_fct(
|
| 72 |
+
start_logits,
|
| 73 |
+
start_positions.float(),
|
| 74 |
+
)
|
| 75 |
+
end_loss = loss = loss_fct(
|
| 76 |
+
end_logits,
|
| 77 |
+
end_positions.float(),
|
| 78 |
+
)
|
| 79 |
+
total_loss = (start_loss + end_loss) / 2
|
| 80 |
+
|
| 81 |
+
output = (start_logits, end_logits) + outputs[2:]
|
| 82 |
+
return ((total_loss,) + output) if total_loss is not None else output
|
src/models/roberta_token_spans.py
ADDED
|
@@ -0,0 +1,97 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import torch.nn as nn
|
| 2 |
+
import torch
|
| 3 |
+
from torch.nn import CrossEntropyLoss
|
| 4 |
+
from transformers import RobertaModel
|
| 5 |
+
from transformers.models.roberta.modeling_roberta import RobertaPreTrainedModel
|
| 6 |
+
from src.utils.mapper import configmapper
|
| 7 |
+
|
| 8 |
+
|
| 9 |
+
@configmapper.map("models", "roberta_token_spans")
|
| 10 |
+
class RobertaModelForTokenAndSpans(RobertaPreTrainedModel):
|
| 11 |
+
def __init__(self, config, num_token_labels=2, num_qa_labels=2):
|
| 12 |
+
super(RobertaModelForTokenAndSpans, self).__init__(config)
|
| 13 |
+
self.roberta = RobertaModel(config)
|
| 14 |
+
self.num_token_labels = num_token_labels
|
| 15 |
+
self.num_qa_labels = num_qa_labels
|
| 16 |
+
|
| 17 |
+
self.dropout = nn.Dropout(config.hidden_dropout_prob)
|
| 18 |
+
self.classifier = nn.Linear(config.hidden_size, num_token_labels)
|
| 19 |
+
self.qa_outputs = nn.Linear(config.hidden_size, num_qa_labels)
|
| 20 |
+
self.init_weights()
|
| 21 |
+
|
| 22 |
+
def forward(
|
| 23 |
+
self,
|
| 24 |
+
input_ids=None,
|
| 25 |
+
attention_mask=None,
|
| 26 |
+
token_type_ids=None,
|
| 27 |
+
position_ids=None,
|
| 28 |
+
head_mask=None,
|
| 29 |
+
inputs_embeds=None,
|
| 30 |
+
start_positions=None,
|
| 31 |
+
end_positions=None,
|
| 32 |
+
labels=None, # Token Wise Labels
|
| 33 |
+
output_attentions=None,
|
| 34 |
+
output_hidden_states=None,
|
| 35 |
+
):
|
| 36 |
+
|
| 37 |
+
outputs = self.roberta(
|
| 38 |
+
input_ids,
|
| 39 |
+
attention_mask=attention_mask,
|
| 40 |
+
token_type_ids=token_type_ids,
|
| 41 |
+
position_ids=position_ids,
|
| 42 |
+
head_mask=head_mask,
|
| 43 |
+
inputs_embeds=inputs_embeds,
|
| 44 |
+
output_attentions=output_attentions,
|
| 45 |
+
output_hidden_states=output_hidden_states,
|
| 46 |
+
return_dict=None,
|
| 47 |
+
)
|
| 48 |
+
|
| 49 |
+
sequence_output = outputs[0]
|
| 50 |
+
|
| 51 |
+
qa_logits = self.qa_outputs(sequence_output)
|
| 52 |
+
start_logits, end_logits = qa_logits.split(1, dim=-1)
|
| 53 |
+
start_logits = start_logits.squeeze(-1)
|
| 54 |
+
end_logits = end_logits.squeeze(-1)
|
| 55 |
+
|
| 56 |
+
sequence_output = self.dropout(sequence_output)
|
| 57 |
+
token_logits = self.classifier(sequence_output)
|
| 58 |
+
|
| 59 |
+
total_loss = None
|
| 60 |
+
if (
|
| 61 |
+
start_positions is not None
|
| 62 |
+
and end_positions is not None
|
| 63 |
+
and labels is not None
|
| 64 |
+
):
|
| 65 |
+
# If we are on multi-GPU, split add a dimension
|
| 66 |
+
if len(start_positions.size()) > 1:
|
| 67 |
+
start_positions = start_positions.squeeze(-1)
|
| 68 |
+
if len(end_positions.size()) > 1:
|
| 69 |
+
end_positions = end_positions.squeeze(-1)
|
| 70 |
+
|
| 71 |
+
ignored_index = start_logits.size(1)
|
| 72 |
+
start_positions.clamp_(0, ignored_index)
|
| 73 |
+
end_positions.clamp_(0, ignored_index)
|
| 74 |
+
|
| 75 |
+
loss_fct = CrossEntropyLoss(ignore_index=ignored_index)
|
| 76 |
+
start_loss = loss_fct(start_logits, start_positions)
|
| 77 |
+
end_loss = loss_fct(end_logits, end_positions)
|
| 78 |
+
|
| 79 |
+
loss_fct = CrossEntropyLoss()
|
| 80 |
+
if attention_mask is not None:
|
| 81 |
+
active_loss = attention_mask.view(-1) == 1
|
| 82 |
+
active_logits = token_logits.view(-1, self.num_token_labels)
|
| 83 |
+
active_labels = torch.where(
|
| 84 |
+
active_loss,
|
| 85 |
+
labels.view(-1),
|
| 86 |
+
torch.tensor(loss_fct.ignore_index).type_as(labels),
|
| 87 |
+
)
|
| 88 |
+
token_loss = loss_fct(active_logits, active_labels)
|
| 89 |
+
else:
|
| 90 |
+
token_loss = loss_fct(
|
| 91 |
+
token_logits.view(-1, self.num_token_labels), labels.view(-1)
|
| 92 |
+
)
|
| 93 |
+
|
| 94 |
+
total_loss = (start_loss + end_loss) / 2 + token_loss
|
| 95 |
+
|
| 96 |
+
output = (start_logits, end_logits, token_logits) + outputs[2:]
|
| 97 |
+
return ((total_loss,) + output) if total_loss is not None else output
|
src/models/two_layer_nn.py
ADDED
|
@@ -0,0 +1,46 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""Implements a two layer Neural Network."""
|
| 2 |
+
|
| 3 |
+
from torch.nn import Module, Linear, ReLU
|
| 4 |
+
from src.utils.mapper import configmapper
|
| 5 |
+
|
| 6 |
+
|
| 7 |
+
@configmapper.map("models", "two_layer_nn")
|
| 8 |
+
class TwoLayerNN(Module):
|
| 9 |
+
"""Implements two layer neural network.
|
| 10 |
+
|
| 11 |
+
Methods:
|
| 12 |
+
forward(x_input): Returns the output of the neural network.
|
| 13 |
+
"""
|
| 14 |
+
|
| 15 |
+
def __init__(self, embedding, dims):
|
| 16 |
+
"""Construct the two layer Neural Network.
|
| 17 |
+
|
| 18 |
+
This method is used to initialize the two layer neural network,
|
| 19 |
+
with a given embedding type and corresponding arguments.
|
| 20 |
+
|
| 21 |
+
Args:
|
| 22 |
+
embedding (torch.nn.Module): The embedding layer for the model.
|
| 23 |
+
dims (list): List of dimensions for the neural network, input to output.
|
| 24 |
+
"""
|
| 25 |
+
super(TwoLayerNN, self).__init__()
|
| 26 |
+
|
| 27 |
+
self.embedding = embedding
|
| 28 |
+
self.linear1 = Linear(dims[0], dims[1])
|
| 29 |
+
self.relu = ReLU()
|
| 30 |
+
self.linear2 = Linear(dims[1], dims[2])
|
| 31 |
+
|
| 32 |
+
def forward(self, x_input):
|
| 33 |
+
"""
|
| 34 |
+
Return the output of the neural network for an input.
|
| 35 |
+
|
| 36 |
+
Args:
|
| 37 |
+
x_input (torch.Tensor): The input tensor to the neural network.
|
| 38 |
+
|
| 39 |
+
Returns:
|
| 40 |
+
x_output (torch.Tensor): The output tensor for the neural network.
|
| 41 |
+
"""
|
| 42 |
+
output = self.embedding(x_input)
|
| 43 |
+
output = self.linear1(output)
|
| 44 |
+
output = self.relu(output)
|
| 45 |
+
x_output = self.linear2(output)
|
| 46 |
+
return x_output
|
src/modules/__init__.py
ADDED
|
File without changes
|
src/modules/__pycache__/__init__.cpython-38.pyc
ADDED
|
Binary file (166 Bytes). View file
|
|
|
src/modules/__pycache__/embeddings.cpython-38.pyc
ADDED
|
Binary file (1.67 kB). View file
|
|
|
src/modules/__pycache__/preprocessors.cpython-38.pyc
ADDED
|
Binary file (3.42 kB). View file
|
|
|
src/modules/__pycache__/tokenizers.cpython-38.pyc
ADDED
|
Binary file (4.87 kB). View file
|
|
|
src/modules/activations.py
ADDED
|
@@ -0,0 +1,6 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import torch.nn as nn
|
| 2 |
+
from src.utils.mapper import configmapper
|
| 3 |
+
|
| 4 |
+
configmapper.map("activations", "relu")(nn.ReLU)
|
| 5 |
+
configmapper.map("activations", "logsoftmax")(nn.LogSoftmax)
|
| 6 |
+
configmapper.map("activations", "softmax")(nn.Softmax)
|
src/modules/embeddings.py
ADDED
|
@@ -0,0 +1,37 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""Contains various kinds of embeddings like Glove, BERT, etc."""
|
| 2 |
+
|
| 3 |
+
from torch.nn import Module, Embedding, Flatten
|
| 4 |
+
from src.utils.mapper import configmapper
|
| 5 |
+
|
| 6 |
+
|
| 7 |
+
@configmapper.map("embeddings", "glove")
|
| 8 |
+
class GloveEmbedding(Module):
|
| 9 |
+
"""Implement Glove based Word Embedding."""
|
| 10 |
+
|
| 11 |
+
def __init__(self, embedding_matrix, padding_idx, static=True):
|
| 12 |
+
"""Construct GloveEmbedding.
|
| 13 |
+
|
| 14 |
+
Args:
|
| 15 |
+
embedding_matrix (torch.Tensor): The matrix contrainining the embedding weights
|
| 16 |
+
padding_idx (int): The padding index in the tokenizer.
|
| 17 |
+
static (bool): Whether or not to freeze embeddings.
|
| 18 |
+
"""
|
| 19 |
+
super(GloveEmbedding, self).__init__()
|
| 20 |
+
self.embedding = Embedding.from_pretrained(embedding_matrix)
|
| 21 |
+
self.embedding.padding_idx = padding_idx
|
| 22 |
+
if static:
|
| 23 |
+
self.embedding.weight.required_grad = False
|
| 24 |
+
self.flatten = Flatten(start_dim=1)
|
| 25 |
+
|
| 26 |
+
def forward(self, x_input):
|
| 27 |
+
"""Pass the input through the embedding.
|
| 28 |
+
|
| 29 |
+
Args:
|
| 30 |
+
x_input (torch.Tensor): The numericalized tokenized input
|
| 31 |
+
|
| 32 |
+
Returns:
|
| 33 |
+
x_output (torch.Tensor): The output from the embedding
|
| 34 |
+
"""
|
| 35 |
+
x_output = self.embedding(x_input)
|
| 36 |
+
x_output = self.flatten(x_output)
|
| 37 |
+
return x_output
|
src/modules/losses.py
ADDED
|
@@ -0,0 +1,6 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"All criterion functions."
|
| 2 |
+
from torch.nn import MSELoss, CrossEntropyLoss
|
| 3 |
+
from src.utils.mapper import configmapper
|
| 4 |
+
|
| 5 |
+
configmapper.map("losses", "mse")(MSELoss)
|
| 6 |
+
configmapper.map("losses", "CrossEntropyLoss")(CrossEntropyLoss)
|
src/modules/metrics.py
ADDED
|
@@ -0,0 +1,17 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""Metrics."""
|
| 2 |
+
from sklearn.metrics import (
|
| 3 |
+
mean_squared_error,
|
| 4 |
+
f1_score,
|
| 5 |
+
precision_score,
|
| 6 |
+
recall_score,
|
| 7 |
+
roc_auc_score,
|
| 8 |
+
accuracy_score,
|
| 9 |
+
)
|
| 10 |
+
from src.utils.mapper import configmapper
|
| 11 |
+
|
| 12 |
+
configmapper.map("metrics", "sklearn_f1")(f1_score)
|
| 13 |
+
configmapper.map("metrics", "sklearn_p")(precision_score)
|
| 14 |
+
configmapper.map("metrics", "sklearn_r")(recall_score)
|
| 15 |
+
configmapper.map("metrics", "sklearn_roc")(roc_auc_score)
|
| 16 |
+
configmapper.map("metrics", "sklearn_acc")(accuracy_score)
|
| 17 |
+
configmapper.map("metrics", "sklearn_mse")(mean_squared_error)
|
src/modules/optimizers.py
ADDED
|
@@ -0,0 +1,7 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
" Method containing activation functions"
|
| 2 |
+
from torch.optim import Adam, AdamW, SGD
|
| 3 |
+
from src.utils.mapper import configmapper
|
| 4 |
+
|
| 5 |
+
configmapper.map("optimizers", "adam")(Adam)
|
| 6 |
+
configmapper.map("optimizers", "adam_w")(AdamW)
|
| 7 |
+
configmapper.map("optimizers", "sgd")(SGD)
|
src/modules/preprocessors.py
ADDED
|
@@ -0,0 +1,112 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from src.modules.tokenizers import *
|
| 2 |
+
from src.modules.embeddings import *
|
| 3 |
+
from src.utils.mapper import configmapper
|
| 4 |
+
|
| 5 |
+
|
| 6 |
+
class Preprocessor:
|
| 7 |
+
def preprocess(self):
|
| 8 |
+
pass
|
| 9 |
+
|
| 10 |
+
|
| 11 |
+
@configmapper.map("preprocessors", "glove")
|
| 12 |
+
class GlovePreprocessor(Preprocessor):
|
| 13 |
+
"""GlovePreprocessor."""
|
| 14 |
+
|
| 15 |
+
def __init__(self, config):
|
| 16 |
+
"""
|
| 17 |
+
Args:
|
| 18 |
+
config (src.utils.module.Config): configuration for preprocessor
|
| 19 |
+
"""
|
| 20 |
+
super(GlovePreprocessor, self).__init__()
|
| 21 |
+
self.config = config
|
| 22 |
+
self.tokenizer = configmapper.get_object(
|
| 23 |
+
"tokenizers", self.config.main.preprocessor.tokenizer.name
|
| 24 |
+
)(**self.config.main.preprocessor.tokenizer.init_params.as_dict())
|
| 25 |
+
self.tokenizer_params = (
|
| 26 |
+
self.config.main.preprocessor.tokenizer.init_vector_params.as_dict()
|
| 27 |
+
)
|
| 28 |
+
|
| 29 |
+
self.tokenizer.initialize_vectors(**self.tokenizer_params)
|
| 30 |
+
self.embeddings = configmapper.get_object(
|
| 31 |
+
"embeddings", self.config.main.preprocessor.embedding.name
|
| 32 |
+
)(
|
| 33 |
+
self.tokenizer.text_field.vocab.vectors,
|
| 34 |
+
self.tokenizer.text_field.vocab.stoi[self.tokenizer.text_field.pad_token],
|
| 35 |
+
)
|
| 36 |
+
|
| 37 |
+
def preprocess(self, model_config, data_config):
|
| 38 |
+
train_dataset = configmapper.get_object("datasets", data_config.main.name)(
|
| 39 |
+
data_config.train, self.tokenizer
|
| 40 |
+
)
|
| 41 |
+
val_dataset = configmapper.get_object("datasets", data_config.main.name)(
|
| 42 |
+
data_config.val, self.tokenizer
|
| 43 |
+
)
|
| 44 |
+
model = configmapper.get_object("models", model_config.name)(
|
| 45 |
+
self.embeddings, **model_config.params.as_dict()
|
| 46 |
+
)
|
| 47 |
+
|
| 48 |
+
return model, train_dataset, val_dataset
|
| 49 |
+
|
| 50 |
+
|
| 51 |
+
@configmapper.map("preprocessors", "clozePreprocessor")
|
| 52 |
+
class ClozePreprocessor(Preprocessor):
|
| 53 |
+
"""GlovePreprocessor."""
|
| 54 |
+
|
| 55 |
+
def __init__(self, config):
|
| 56 |
+
"""
|
| 57 |
+
Args:
|
| 58 |
+
config (src.utils.module.Config): configuration for preprocessor
|
| 59 |
+
"""
|
| 60 |
+
super(ClozePreprocessor, self).__init__()
|
| 61 |
+
self.config = config
|
| 62 |
+
self.tokenizer = configmapper.get_object(
|
| 63 |
+
"tokenizers", self.config.main.preprocessor.tokenizer.name
|
| 64 |
+
).from_pretrained(
|
| 65 |
+
**self.config.main.preprocessor.tokenizer.init_params.as_dict()
|
| 66 |
+
)
|
| 67 |
+
|
| 68 |
+
def preprocess(self, model_config, data_config):
|
| 69 |
+
train_dataset = configmapper.get_object("datasets", data_config.main.name)(
|
| 70 |
+
data_config.train, self.tokenizer
|
| 71 |
+
)
|
| 72 |
+
val_dataset = configmapper.get_object("datasets", data_config.main.name)(
|
| 73 |
+
data_config.val, self.tokenizer
|
| 74 |
+
)
|
| 75 |
+
model = configmapper.get_object("models", model_config.name).from_pretrained(
|
| 76 |
+
**model_config.params.as_dict()
|
| 77 |
+
)
|
| 78 |
+
|
| 79 |
+
return model, train_dataset, val_dataset
|
| 80 |
+
|
| 81 |
+
|
| 82 |
+
@configmapper.map("preprocessors", "transformersConcretenessPreprocessor")
|
| 83 |
+
class TransformersConcretenessPreprocessor(Preprocessor):
|
| 84 |
+
"""BertConcretenessPreprocessor."""
|
| 85 |
+
|
| 86 |
+
def __init__(self, config):
|
| 87 |
+
"""
|
| 88 |
+
Args:
|
| 89 |
+
config (src.utils.module.Config): configuration for preprocessor
|
| 90 |
+
"""
|
| 91 |
+
super(TransformersConcretenessPreprocessor, self).__init__()
|
| 92 |
+
self.config = config
|
| 93 |
+
self.tokenizer = configmapper.get_object(
|
| 94 |
+
"tokenizers", self.config.main.preprocessor.tokenizer.name
|
| 95 |
+
).from_pretrained(
|
| 96 |
+
**self.config.main.preprocessor.tokenizer.init_params.as_dict()
|
| 97 |
+
)
|
| 98 |
+
|
| 99 |
+
def preprocess(self, model_config, data_config):
|
| 100 |
+
|
| 101 |
+
train_dataset = configmapper.get_object("datasets", data_config.main.name)(
|
| 102 |
+
data_config.train, self.tokenizer
|
| 103 |
+
)
|
| 104 |
+
val_dataset = configmapper.get_object("datasets", data_config.main.name)(
|
| 105 |
+
data_config.val, self.tokenizer
|
| 106 |
+
)
|
| 107 |
+
|
| 108 |
+
model = configmapper.get_object("models", model_config.name)(
|
| 109 |
+
**model_config.params.as_dict()
|
| 110 |
+
)
|
| 111 |
+
|
| 112 |
+
return model, train_dataset, val_dataset
|
src/modules/schedulers.py
ADDED
|
@@ -0,0 +1,14 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from torch.optim.lr_scheduler import (
|
| 2 |
+
StepLR,
|
| 3 |
+
CosineAnnealingLR,
|
| 4 |
+
ReduceLROnPlateau,
|
| 5 |
+
CyclicLR,
|
| 6 |
+
CosineAnnealingWarmRestarts,
|
| 7 |
+
)
|
| 8 |
+
from src.utils.mapper import configmapper
|
| 9 |
+
|
| 10 |
+
configmapper.map("schedulers", "step")(StepLR)
|
| 11 |
+
configmapper.map("schedulers", "cosineanneal")(CosineAnnealingLR)
|
| 12 |
+
configmapper.map("schedulers", "reduceplateau")(ReduceLROnPlateau)
|
| 13 |
+
configmapper.map("schedulers", "cyclic")(CyclicLR)
|
| 14 |
+
configmapper.map("schedulers", "cosineannealrestart")(CosineAnnealingWarmRestarts)
|
src/modules/tokenizers.py
ADDED
|
@@ -0,0 +1,107 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""Contains tokenizers like GloveTokenizers and BERT Tokenizer."""
|
| 2 |
+
|
| 3 |
+
import torch
|
| 4 |
+
# from torchtext.vocab import GloVe
|
| 5 |
+
# from torchtext.data import Field, TabularDataset
|
| 6 |
+
from src.utils.mapper import configmapper
|
| 7 |
+
from transformers import AutoTokenizer
|
| 8 |
+
|
| 9 |
+
|
| 10 |
+
class Tokenizer:
|
| 11 |
+
"""Abstract Class for Tokenizers."""
|
| 12 |
+
|
| 13 |
+
def tokenize(self):
|
| 14 |
+
"""Abstract Method for tokenization."""
|
| 15 |
+
|
| 16 |
+
|
| 17 |
+
@configmapper.map("tokenizers", "glove")
|
| 18 |
+
class GloveTokenizer(Tokenizer):
|
| 19 |
+
"""Implement GloveTokenizer for tokenizing text for Glove Embeddings.
|
| 20 |
+
|
| 21 |
+
Attributes:
|
| 22 |
+
embeddings (torchtext.vocab.Vectors): Loaded pre-trained embeddings.
|
| 23 |
+
text_field (torchtext.data.Field): Text_field for vector creation.
|
| 24 |
+
|
| 25 |
+
Methods:
|
| 26 |
+
__init__(self, name='840B', dim='300', cache='../embeddings/') : Constructor method
|
| 27 |
+
initialize_vectors(fix_length=4, tokenize='spacy', file_path="../data/imperceptibility
|
| 28 |
+
/Concreteness Ratings/train/forty.csv",
|
| 29 |
+
file_format='tsv', fields=None): Initialize vocab vectors based on data.
|
| 30 |
+
|
| 31 |
+
tokenize(x_input, **initializer_params): Tokenize given input and return the output.
|
| 32 |
+
"""
|
| 33 |
+
|
| 34 |
+
def __init__(self, name="840B", dim="300", cache="../embeddings/"):
|
| 35 |
+
"""Construct GloveTokenizer.
|
| 36 |
+
|
| 37 |
+
Args:
|
| 38 |
+
name (str): Name of the GloVe embedding file
|
| 39 |
+
dim (str): Dimensions of the Glove embedding file
|
| 40 |
+
cache (str): Path to the embeddings directory
|
| 41 |
+
"""
|
| 42 |
+
super(GloveTokenizer, self).__init__()
|
| 43 |
+
self.embeddings = GloVe(name=name, dim=dim, cache=cache)
|
| 44 |
+
self.text_field = None
|
| 45 |
+
|
| 46 |
+
def initialize_vectors(
|
| 47 |
+
self,
|
| 48 |
+
fix_length=4,
|
| 49 |
+
tokenize="spacy",
|
| 50 |
+
tokenizer_file_paths=None,
|
| 51 |
+
file_format="tsv",
|
| 52 |
+
fields=None,
|
| 53 |
+
):
|
| 54 |
+
"""Initialize words/sequences based on GloVe embedding.
|
| 55 |
+
|
| 56 |
+
Args:
|
| 57 |
+
fields (list): The list containing the fields to be taken
|
| 58 |
+
and processed from the file (see documentation for
|
| 59 |
+
torchtext.data.TabularDataset)
|
| 60 |
+
fix_length (int): The length of the tokenized text,
|
| 61 |
+
padding or cropping is done accordingly
|
| 62 |
+
tokenize (function or string): Method to tokenize the data.
|
| 63 |
+
If 'spacy' uses spacy tokenizer,
|
| 64 |
+
else the specified method.
|
| 65 |
+
tokenizer_file_paths (list of str): The paths of the files containing the data
|
| 66 |
+
format (str): The format of the file : 'csv', 'tsv' or 'json'
|
| 67 |
+
"""
|
| 68 |
+
text_field = Field(batch_first=True, fix_length=fix_length, tokenize=tokenize)
|
| 69 |
+
tab_dats = [
|
| 70 |
+
TabularDataset(
|
| 71 |
+
i, format=file_format, fields={k: (k, text_field) for k in fields}
|
| 72 |
+
)
|
| 73 |
+
for i in tokenizer_file_paths
|
| 74 |
+
]
|
| 75 |
+
text_field.build_vocab(*tab_dats)
|
| 76 |
+
text_field.vocab.load_vectors(self.embeddings)
|
| 77 |
+
self.text_field = text_field
|
| 78 |
+
|
| 79 |
+
def tokenize(self, x_input, **init_vector__params):
|
| 80 |
+
"""Tokenize given input based on initialized vectors.
|
| 81 |
+
|
| 82 |
+
Initialize the vectors with given parameters if not already initialized.
|
| 83 |
+
|
| 84 |
+
Args:
|
| 85 |
+
x_input (str): Unprocessed input text to be tokenized
|
| 86 |
+
**initializer_params (Keyword arguments): Parameters to initialize vectors
|
| 87 |
+
|
| 88 |
+
Returns:
|
| 89 |
+
x_output (str): Processed and tokenized text
|
| 90 |
+
"""
|
| 91 |
+
if self.text_field is None:
|
| 92 |
+
self.initialize_vectors(**init_vector__params)
|
| 93 |
+
try:
|
| 94 |
+
x_output = torch.squeeze(
|
| 95 |
+
self.text_field.process([self.text_field.preprocess(x_input)])
|
| 96 |
+
)
|
| 97 |
+
except Exception as e:
|
| 98 |
+
print(x_input)
|
| 99 |
+
print(self.text_field.preprocess(x_input))
|
| 100 |
+
print(e)
|
| 101 |
+
return x_output
|
| 102 |
+
|
| 103 |
+
|
| 104 |
+
@configmapper.map("tokenizers", "AutoTokenizer")
|
| 105 |
+
class AutoTokenizer(AutoTokenizer):
|
| 106 |
+
def __init__(self, *args):
|
| 107 |
+
super(AutoTokenizer, self).__init__()
|
src/trainers/__init__.py
ADDED
|
File without changes
|
src/trainers/base_trainer.py
ADDED
|
@@ -0,0 +1,563 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import math
|
| 2 |
+
import os
|
| 3 |
+
import torch
|
| 4 |
+
from src.modules.optimizers import *
|
| 5 |
+
from src.modules.embeddings import *
|
| 6 |
+
from src.modules.schedulers import *
|
| 7 |
+
from src.modules.tokenizers import *
|
| 8 |
+
from src.modules.metrics import *
|
| 9 |
+
from src.modules.losses import *
|
| 10 |
+
from src.utils.misc import *
|
| 11 |
+
from src.utils.logger import Logger
|
| 12 |
+
from src.utils.mapper import configmapper
|
| 13 |
+
from src.utils.configuration import Config
|
| 14 |
+
|
| 15 |
+
from torch.utils.data import DataLoader
|
| 16 |
+
from tqdm import tqdm
|
| 17 |
+
|
| 18 |
+
|
| 19 |
+
@configmapper.map("trainers", "base")
|
| 20 |
+
class BaseTrainer:
|
| 21 |
+
def __init__(self, config):
|
| 22 |
+
self._config = config
|
| 23 |
+
self.metrics = {
|
| 24 |
+
configmapper.get_object("metrics", metric["type"]): metric["params"]
|
| 25 |
+
for metric in self._config.main_config.metrics
|
| 26 |
+
}
|
| 27 |
+
self.train_config = self._config.train
|
| 28 |
+
self.val_config = self._config.val
|
| 29 |
+
self.log_label = self.train_config.log.log_label
|
| 30 |
+
if self.train_config.log_and_val_interval is not None:
|
| 31 |
+
self.val_log_together = True
|
| 32 |
+
print("Logging with label: ", self.log_label)
|
| 33 |
+
|
| 34 |
+
def train(self, model, train_dataset, val_dataset=None, logger=None):
|
| 35 |
+
device = torch.device(self._config.main_config.device.name)
|
| 36 |
+
model.to(device)
|
| 37 |
+
optim_params = self.train_config.optimizer.params
|
| 38 |
+
if optim_params:
|
| 39 |
+
optimizer = configmapper.get_object(
|
| 40 |
+
"optimizers", self.train_config.optimizer.type
|
| 41 |
+
)(model.parameters(), **optim_params.as_dict())
|
| 42 |
+
else:
|
| 43 |
+
optimizer = configmapper.get_object(
|
| 44 |
+
"optimizers", self.train_config.optimizer.type
|
| 45 |
+
)(model.parameters())
|
| 46 |
+
|
| 47 |
+
if self.train_config.scheduler is not None:
|
| 48 |
+
scheduler_params = self.train_config.scheduler.params
|
| 49 |
+
if scheduler_params:
|
| 50 |
+
scheduler = configmapper.get_object(
|
| 51 |
+
"schedulers", self.train_config.scheduler.type
|
| 52 |
+
)(optimizer, **scheduler_params.as_dict())
|
| 53 |
+
else:
|
| 54 |
+
scheduler = configmapper.get_object(
|
| 55 |
+
"schedulers", self.train_config.scheduler.type
|
| 56 |
+
)(optimizer)
|
| 57 |
+
|
| 58 |
+
criterion_params = self.train_config.criterion.params
|
| 59 |
+
if criterion_params:
|
| 60 |
+
criterion = configmapper.get_object(
|
| 61 |
+
"losses", self.train_config.criterion.type
|
| 62 |
+
)(**criterion_params.as_dict())
|
| 63 |
+
else:
|
| 64 |
+
criterion = configmapper.get_object(
|
| 65 |
+
"losses", self.train_config.criterion.type
|
| 66 |
+
)()
|
| 67 |
+
if "custom_collate_fn" in dir(train_dataset):
|
| 68 |
+
train_loader = DataLoader(
|
| 69 |
+
dataset=train_dataset,
|
| 70 |
+
collate_fn=train_dataset.custom_collate_fn,
|
| 71 |
+
**self.train_config.loader_params.as_dict(),
|
| 72 |
+
)
|
| 73 |
+
else:
|
| 74 |
+
train_loader = DataLoader(
|
| 75 |
+
dataset=train_dataset, **self.train_config.loader_params.as_dict()
|
| 76 |
+
)
|
| 77 |
+
# train_logger = Logger(**self.train_config.log.logger_params.as_dict())
|
| 78 |
+
|
| 79 |
+
max_epochs = self.train_config.max_epochs
|
| 80 |
+
batch_size = self.train_config.loader_params.batch_size
|
| 81 |
+
|
| 82 |
+
if self.val_log_together:
|
| 83 |
+
val_interval = self.train_config.log_and_val_interval
|
| 84 |
+
log_interval = val_interval
|
| 85 |
+
else:
|
| 86 |
+
val_interval = self.train_config.val_interval
|
| 87 |
+
log_interval = self.train_config.log.log_interval
|
| 88 |
+
|
| 89 |
+
if logger is None:
|
| 90 |
+
train_logger = Logger(**self.train_config.log.logger_params.as_dict())
|
| 91 |
+
else:
|
| 92 |
+
train_logger = logger
|
| 93 |
+
|
| 94 |
+
train_log_values = self.train_config.log.values.as_dict()
|
| 95 |
+
|
| 96 |
+
best_score = (
|
| 97 |
+
-math.inf if self.train_config.save_on.desired == "max" else math.inf
|
| 98 |
+
)
|
| 99 |
+
save_on_score = self.train_config.save_on.score
|
| 100 |
+
best_step = -1
|
| 101 |
+
best_model = None
|
| 102 |
+
|
| 103 |
+
best_hparam_list = None
|
| 104 |
+
best_hparam_name_list = None
|
| 105 |
+
best_metrics_list = None
|
| 106 |
+
best_metrics_name_list = None
|
| 107 |
+
|
| 108 |
+
# print("\nTraining\n")
|
| 109 |
+
# print(max_steps)
|
| 110 |
+
|
| 111 |
+
global_step = 0
|
| 112 |
+
for epoch in range(1, max_epochs + 1):
|
| 113 |
+
print(
|
| 114 |
+
"Epoch: {}/{}, Global Step: {}".format(epoch, max_epochs, global_step)
|
| 115 |
+
)
|
| 116 |
+
train_loss = 0
|
| 117 |
+
val_loss = 0
|
| 118 |
+
|
| 119 |
+
if(self.train_config.label_type=='float'):
|
| 120 |
+
all_labels = torch.FloatTensor().to(device)
|
| 121 |
+
else:
|
| 122 |
+
all_labels = torch.LongTensor().to(device)
|
| 123 |
+
|
| 124 |
+
all_outputs = torch.Tensor().to(device)
|
| 125 |
+
|
| 126 |
+
train_scores = None
|
| 127 |
+
val_scores = None
|
| 128 |
+
|
| 129 |
+
pbar = tqdm(total=math.ceil(len(train_dataset) / batch_size))
|
| 130 |
+
pbar.set_description("Epoch " + str(epoch))
|
| 131 |
+
|
| 132 |
+
val_counter = 0
|
| 133 |
+
|
| 134 |
+
for step, batch in enumerate(train_loader):
|
| 135 |
+
model.train()
|
| 136 |
+
optimizer.zero_grad()
|
| 137 |
+
inputs, labels = batch
|
| 138 |
+
|
| 139 |
+
if(self.train_config.label_type=='float'): ##Specific to Float Type
|
| 140 |
+
labels = labels.float()
|
| 141 |
+
|
| 142 |
+
for key in inputs:
|
| 143 |
+
inputs[key] = inputs[key].to(device)
|
| 144 |
+
labels = labels.to(device)
|
| 145 |
+
outputs = model(inputs)
|
| 146 |
+
loss = criterion(torch.squeeze(outputs), labels)
|
| 147 |
+
loss.backward()
|
| 148 |
+
|
| 149 |
+
all_labels = torch.cat((all_labels, labels), 0)
|
| 150 |
+
|
| 151 |
+
if (self.train_config.label_type=='float'):
|
| 152 |
+
all_outputs = torch.cat((all_outputs, outputs), 0)
|
| 153 |
+
else:
|
| 154 |
+
all_outputs = torch.cat((all_outputs, torch.argmax(outputs, axis=1)), 0)
|
| 155 |
+
|
| 156 |
+
|
| 157 |
+
train_loss += loss.item()
|
| 158 |
+
optimizer.step()
|
| 159 |
+
|
| 160 |
+
if self.train_config.scheduler is not None:
|
| 161 |
+
if isinstance(scheduler, ReduceLROnPlateau):
|
| 162 |
+
scheduler.step(train_loss / (step + 1))
|
| 163 |
+
else:
|
| 164 |
+
scheduler.step()
|
| 165 |
+
|
| 166 |
+
# print(train_loss)
|
| 167 |
+
# print(step+1)
|
| 168 |
+
|
| 169 |
+
pbar.set_postfix_str(f"Train Loss: {train_loss /(step+1)}")
|
| 170 |
+
pbar.update(1)
|
| 171 |
+
|
| 172 |
+
global_step += 1
|
| 173 |
+
|
| 174 |
+
# Need to check if we want global_step or local_step
|
| 175 |
+
|
| 176 |
+
if val_dataset is not None and (global_step - 1) % val_interval == 0:
|
| 177 |
+
# print("\nEvaluating\n")
|
| 178 |
+
val_scores = self.val(
|
| 179 |
+
model,
|
| 180 |
+
val_dataset,
|
| 181 |
+
criterion,
|
| 182 |
+
device,
|
| 183 |
+
global_step,
|
| 184 |
+
train_logger,
|
| 185 |
+
train_log_values,
|
| 186 |
+
)
|
| 187 |
+
|
| 188 |
+
#save_flag = 0
|
| 189 |
+
if self.train_config.save_on is not None:
|
| 190 |
+
|
| 191 |
+
## BEST SCORES UPDATING
|
| 192 |
+
|
| 193 |
+
train_scores = self.get_scores(
|
| 194 |
+
train_loss,
|
| 195 |
+
global_step,
|
| 196 |
+
self.train_config.criterion.type,
|
| 197 |
+
all_outputs,
|
| 198 |
+
all_labels,
|
| 199 |
+
)
|
| 200 |
+
|
| 201 |
+
best_score, best_step, save_flag = self.check_best(
|
| 202 |
+
val_scores, save_on_score, best_score, global_step
|
| 203 |
+
)
|
| 204 |
+
|
| 205 |
+
store_dict = {
|
| 206 |
+
"model_state_dict": model.state_dict(),
|
| 207 |
+
"best_step": best_step,
|
| 208 |
+
"best_score": best_score,
|
| 209 |
+
"save_on_score": save_on_score,
|
| 210 |
+
}
|
| 211 |
+
|
| 212 |
+
path = self.train_config.save_on.best_path.format(
|
| 213 |
+
self.log_label
|
| 214 |
+
)
|
| 215 |
+
|
| 216 |
+
self.save(store_dict, path, save_flag)
|
| 217 |
+
|
| 218 |
+
if save_flag and train_log_values["hparams"] is not None:
|
| 219 |
+
(
|
| 220 |
+
best_hparam_list,
|
| 221 |
+
best_hparam_name_list,
|
| 222 |
+
best_metrics_list,
|
| 223 |
+
best_metrics_name_list,
|
| 224 |
+
) = self.update_hparams(
|
| 225 |
+
train_scores, val_scores, desc="best_val"
|
| 226 |
+
)
|
| 227 |
+
# pbar.close()
|
| 228 |
+
if (global_step - 1) % log_interval == 0:
|
| 229 |
+
# print("\nLogging\n")
|
| 230 |
+
train_loss_name = self.train_config.criterion.type
|
| 231 |
+
metric_list = [
|
| 232 |
+
metric(all_labels.cpu(), all_outputs.detach().cpu(), **self.metrics[metric])
|
| 233 |
+
for metric in self.metrics
|
| 234 |
+
]
|
| 235 |
+
metric_name_list = [
|
| 236 |
+
metric['type'] for metric in self._config.main_config.metrics
|
| 237 |
+
]
|
| 238 |
+
|
| 239 |
+
train_scores = self.log(
|
| 240 |
+
train_loss / (step + 1),
|
| 241 |
+
train_loss_name,
|
| 242 |
+
metric_list,
|
| 243 |
+
metric_name_list,
|
| 244 |
+
train_logger,
|
| 245 |
+
train_log_values,
|
| 246 |
+
global_step,
|
| 247 |
+
append_text=self.train_config.append_text,
|
| 248 |
+
)
|
| 249 |
+
pbar.close()
|
| 250 |
+
if not os.path.exists(self.train_config.checkpoint.checkpoint_dir):
|
| 251 |
+
os.makedirs(self.train_config.checkpoint.checkpoint_dir)
|
| 252 |
+
|
| 253 |
+
if self.train_config.save_after_epoch:
|
| 254 |
+
store_dict = {
|
| 255 |
+
"model_state_dict": model.state_dict(),
|
| 256 |
+
}
|
| 257 |
+
|
| 258 |
+
path = f"{self.train_config.checkpoint.checkpoint_dir}_{str(self.train_config.log.log_label)}_{str(epoch)}.pth"
|
| 259 |
+
|
| 260 |
+
self.save(store_dict, path, save_flag=1)
|
| 261 |
+
|
| 262 |
+
if epoch == max_epochs:
|
| 263 |
+
# print("\nEvaluating\n")
|
| 264 |
+
val_scores = self.val(
|
| 265 |
+
model,
|
| 266 |
+
val_dataset,
|
| 267 |
+
criterion,
|
| 268 |
+
device,
|
| 269 |
+
global_step,
|
| 270 |
+
train_logger,
|
| 271 |
+
train_log_values,
|
| 272 |
+
)
|
| 273 |
+
|
| 274 |
+
# print("\nLogging\n")
|
| 275 |
+
train_loss_name = self.train_config.criterion.type
|
| 276 |
+
metric_list = [
|
| 277 |
+
metric(all_labels.cpu(), all_outputs.detach().cpu(),**self.metrics[metric])
|
| 278 |
+
for metric in self.metrics
|
| 279 |
+
]
|
| 280 |
+
metric_name_list = [metric['type'] for metric in self._config.main_config.metrics]
|
| 281 |
+
|
| 282 |
+
train_scores = self.log(
|
| 283 |
+
train_loss / len(train_loader),
|
| 284 |
+
train_loss_name,
|
| 285 |
+
metric_list,
|
| 286 |
+
metric_name_list,
|
| 287 |
+
train_logger,
|
| 288 |
+
train_log_values,
|
| 289 |
+
global_step,
|
| 290 |
+
append_text=self.train_config.append_text,
|
| 291 |
+
)
|
| 292 |
+
|
| 293 |
+
if self.train_config.save_on is not None:
|
| 294 |
+
|
| 295 |
+
## BEST SCORES UPDATING
|
| 296 |
+
|
| 297 |
+
train_scores = self.get_scores(
|
| 298 |
+
train_loss,
|
| 299 |
+
len(train_loader),
|
| 300 |
+
self.train_config.criterion.type,
|
| 301 |
+
all_outputs,
|
| 302 |
+
all_labels,
|
| 303 |
+
)
|
| 304 |
+
|
| 305 |
+
best_score, best_step, save_flag = self.check_best(
|
| 306 |
+
val_scores, save_on_score, best_score, global_step
|
| 307 |
+
)
|
| 308 |
+
|
| 309 |
+
store_dict = {
|
| 310 |
+
"model_state_dict": model.state_dict(),
|
| 311 |
+
"best_step": best_step,
|
| 312 |
+
"best_score": best_score,
|
| 313 |
+
"save_on_score": save_on_score,
|
| 314 |
+
}
|
| 315 |
+
|
| 316 |
+
path = self.train_config.save_on.best_path.format(self.log_label)
|
| 317 |
+
|
| 318 |
+
self.save(store_dict, path, save_flag)
|
| 319 |
+
|
| 320 |
+
if save_flag and train_log_values["hparams"] is not None:
|
| 321 |
+
(
|
| 322 |
+
best_hparam_list,
|
| 323 |
+
best_hparam_name_list,
|
| 324 |
+
best_metrics_list,
|
| 325 |
+
best_metrics_name_list,
|
| 326 |
+
) = self.update_hparams(train_scores, val_scores, desc="best_val")
|
| 327 |
+
|
| 328 |
+
## FINAL SCORES UPDATING + STORING
|
| 329 |
+
train_scores = self.get_scores(
|
| 330 |
+
train_loss,
|
| 331 |
+
len(train_loader),
|
| 332 |
+
self.train_config.criterion.type,
|
| 333 |
+
all_outputs,
|
| 334 |
+
all_labels,
|
| 335 |
+
)
|
| 336 |
+
|
| 337 |
+
store_dict = {
|
| 338 |
+
"model_state_dict": model.state_dict(),
|
| 339 |
+
"final_step": global_step,
|
| 340 |
+
"final_score": train_scores[save_on_score],
|
| 341 |
+
"save_on_score": save_on_score,
|
| 342 |
+
}
|
| 343 |
+
|
| 344 |
+
path = self.train_config.save_on.final_path.format(self.log_label)
|
| 345 |
+
|
| 346 |
+
self.save(store_dict, path, save_flag=1)
|
| 347 |
+
if train_log_values["hparams"] is not None:
|
| 348 |
+
(
|
| 349 |
+
final_hparam_list,
|
| 350 |
+
final_hparam_name_list,
|
| 351 |
+
final_metrics_list,
|
| 352 |
+
final_metrics_name_list,
|
| 353 |
+
) = self.update_hparams(train_scores, val_scores, desc="final")
|
| 354 |
+
train_logger.save_hyperparams(
|
| 355 |
+
best_hparam_list,
|
| 356 |
+
best_hparam_name_list,
|
| 357 |
+
[int(self.log_label),] + best_metrics_list + final_metrics_list,
|
| 358 |
+
["hparams/log_label",]
|
| 359 |
+
+ best_metrics_name_list
|
| 360 |
+
+ final_metrics_name_list,
|
| 361 |
+
)
|
| 362 |
+
#
|
| 363 |
+
|
| 364 |
+
## Need to check if we want same loggers of different loggers for train and eval
|
| 365 |
+
## Evaluate
|
| 366 |
+
|
| 367 |
+
def get_scores(self, loss, divisor, loss_name, all_outputs, all_labels):
|
| 368 |
+
|
| 369 |
+
avg_loss = loss / divisor
|
| 370 |
+
|
| 371 |
+
metric_list = [
|
| 372 |
+
metric(all_labels.cpu(), all_outputs.detach().cpu(), **self.metrics[metric])
|
| 373 |
+
for metric in self.metrics
|
| 374 |
+
]
|
| 375 |
+
metric_name_list = [metric['type'] for metric in self._config.main_config.metrics]
|
| 376 |
+
|
| 377 |
+
return dict(zip([loss_name,] + metric_name_list, [avg_loss,] + metric_list,))
|
| 378 |
+
|
| 379 |
+
def check_best(self, val_scores, save_on_score, best_score, global_step):
|
| 380 |
+
save_flag = 0
|
| 381 |
+
best_step = global_step
|
| 382 |
+
if self.train_config.save_on.desired == "min":
|
| 383 |
+
if val_scores[save_on_score] < best_score:
|
| 384 |
+
save_flag = 1
|
| 385 |
+
best_score = val_scores[save_on_score]
|
| 386 |
+
best_step = global_step
|
| 387 |
+
else:
|
| 388 |
+
if val_scores[save_on_score] > best_score:
|
| 389 |
+
save_flag = 1
|
| 390 |
+
best_score = val_scores[save_on_score]
|
| 391 |
+
best_step = global_step
|
| 392 |
+
return best_score, best_step, save_flag
|
| 393 |
+
|
| 394 |
+
def update_hparams(self, train_scores, val_scores, desc):
|
| 395 |
+
hparam_list = []
|
| 396 |
+
hparam_name_list = []
|
| 397 |
+
for hparam in self.train_config.log.values.hparams:
|
| 398 |
+
hparam_list.append(get_item_in_config(self._config, hparam["path"]))
|
| 399 |
+
if isinstance(hparam_list[-1], Config):
|
| 400 |
+
hparam_list[-1] = hparam_list[-1].as_dict()
|
| 401 |
+
hparam_name_list.append(hparam["name"])
|
| 402 |
+
|
| 403 |
+
val_keys, val_values = zip(*val_scores.items())
|
| 404 |
+
train_keys, train_values = zip(*train_scores.items())
|
| 405 |
+
val_keys = list(val_keys)
|
| 406 |
+
train_keys = list(train_keys)
|
| 407 |
+
val_values = list(val_values)
|
| 408 |
+
train_values = list(train_values)
|
| 409 |
+
for i, key in enumerate(val_keys):
|
| 410 |
+
val_keys[i] = f"hparams/{desc}_val_" + val_keys[i]
|
| 411 |
+
for i, key in enumerate(train_keys):
|
| 412 |
+
train_keys[i] = f"hparams/{desc}_train_" + train_keys[i]
|
| 413 |
+
# train_logger.save_hyperparams(hparam_list, hparam_name_list,train_values+val_values,train_keys+val_keys, )
|
| 414 |
+
return (
|
| 415 |
+
hparam_list,
|
| 416 |
+
hparam_name_list,
|
| 417 |
+
train_values + val_values,
|
| 418 |
+
train_keys + val_keys,
|
| 419 |
+
)
|
| 420 |
+
|
| 421 |
+
def save(self, store_dict, path, save_flag=0):
|
| 422 |
+
if save_flag:
|
| 423 |
+
dirs = "/".join(path.split("/")[:-1])
|
| 424 |
+
if not os.path.exists(dirs):
|
| 425 |
+
os.makedirs(dirs)
|
| 426 |
+
torch.save(store_dict, path)
|
| 427 |
+
|
| 428 |
+
def log(
|
| 429 |
+
self,
|
| 430 |
+
loss,
|
| 431 |
+
loss_name,
|
| 432 |
+
metric_list,
|
| 433 |
+
metric_name_list,
|
| 434 |
+
logger,
|
| 435 |
+
log_values,
|
| 436 |
+
global_step,
|
| 437 |
+
append_text,
|
| 438 |
+
):
|
| 439 |
+
|
| 440 |
+
return_dic = dict(zip([loss_name,] + metric_name_list, [loss,] + metric_list,))
|
| 441 |
+
|
| 442 |
+
loss_name = f"{append_text}_{self.log_label}_{loss_name}"
|
| 443 |
+
if log_values["loss"]:
|
| 444 |
+
logger.save_params(
|
| 445 |
+
[loss],
|
| 446 |
+
[loss_name],
|
| 447 |
+
combine=True,
|
| 448 |
+
combine_name="losses",
|
| 449 |
+
global_step=global_step,
|
| 450 |
+
)
|
| 451 |
+
|
| 452 |
+
for i in range(len(metric_name_list)):
|
| 453 |
+
metric_name_list[
|
| 454 |
+
i
|
| 455 |
+
] = f"{append_text}_{self.log_label}_{metric_name_list[i]}"
|
| 456 |
+
if log_values["metrics"]:
|
| 457 |
+
logger.save_params(
|
| 458 |
+
metric_list,
|
| 459 |
+
metric_name_list,
|
| 460 |
+
combine=True,
|
| 461 |
+
combine_name="metrics",
|
| 462 |
+
global_step=global_step,
|
| 463 |
+
)
|
| 464 |
+
# print(hparams_list)
|
| 465 |
+
# print(hparam_name_list)
|
| 466 |
+
|
| 467 |
+
# for k,v in dict(zip([loss_name],[loss])).items():
|
| 468 |
+
# print(f"{k}:{v}")
|
| 469 |
+
# for k,v in dict(zip(metric_name_list,metric_list)).items():
|
| 470 |
+
# print(f"{k}:{v}")
|
| 471 |
+
return return_dic
|
| 472 |
+
|
| 473 |
+
def val(
|
| 474 |
+
self,
|
| 475 |
+
model,
|
| 476 |
+
dataset,
|
| 477 |
+
criterion,
|
| 478 |
+
device,
|
| 479 |
+
global_step,
|
| 480 |
+
train_logger=None,
|
| 481 |
+
train_log_values=None,
|
| 482 |
+
log=True,
|
| 483 |
+
):
|
| 484 |
+
append_text = self.val_config.append_text
|
| 485 |
+
if train_logger is not None:
|
| 486 |
+
val_logger = train_logger
|
| 487 |
+
else:
|
| 488 |
+
val_logger = Logger(**self.val_config.log.logger_params.as_dict())
|
| 489 |
+
|
| 490 |
+
if train_log_values is not None:
|
| 491 |
+
val_log_values = train_log_values
|
| 492 |
+
else:
|
| 493 |
+
val_log_values = self.val_config.log.values.as_dict()
|
| 494 |
+
if "custom_collate_fn" in dir(dataset):
|
| 495 |
+
val_loader = DataLoader(
|
| 496 |
+
dataset=dataset,
|
| 497 |
+
collate_fn=dataset.custom_collate_fn,
|
| 498 |
+
**self.val_config.loader_params.as_dict(),
|
| 499 |
+
)
|
| 500 |
+
else:
|
| 501 |
+
val_loader = DataLoader(
|
| 502 |
+
dataset=dataset, **self.val_config.loader_params.as_dict()
|
| 503 |
+
)
|
| 504 |
+
|
| 505 |
+
all_outputs = torch.Tensor().to(device)
|
| 506 |
+
if(self.train_config.label_type=='float'):
|
| 507 |
+
all_labels = torch.FloatTensor().to(device)
|
| 508 |
+
else:
|
| 509 |
+
all_labels = torch.LongTensor().to(device)
|
| 510 |
+
|
| 511 |
+
batch_size = self.val_config.loader_params.batch_size
|
| 512 |
+
|
| 513 |
+
with torch.no_grad():
|
| 514 |
+
model.eval()
|
| 515 |
+
val_loss = 0
|
| 516 |
+
for j, batch in enumerate(val_loader):
|
| 517 |
+
|
| 518 |
+
inputs, labels = batch
|
| 519 |
+
|
| 520 |
+
if(self.train_config.label_type=='float'):
|
| 521 |
+
labels = labels.float()
|
| 522 |
+
|
| 523 |
+
for key in inputs:
|
| 524 |
+
inputs[key] = inputs[key].to(device)
|
| 525 |
+
labels = labels.to(device)
|
| 526 |
+
|
| 527 |
+
outputs = model(inputs)
|
| 528 |
+
loss = criterion(torch.squeeze(outputs), labels)
|
| 529 |
+
val_loss += loss.item()
|
| 530 |
+
|
| 531 |
+
all_labels = torch.cat((all_labels, labels), 0)
|
| 532 |
+
|
| 533 |
+
if (self.train_config.label_type=='float'):
|
| 534 |
+
all_outputs = torch.cat((all_outputs, outputs), 0)
|
| 535 |
+
else:
|
| 536 |
+
all_outputs = torch.cat((all_outputs, torch.argmax(outputs, axis=1)), 0)
|
| 537 |
+
|
| 538 |
+
val_loss = val_loss / len(val_loader)
|
| 539 |
+
|
| 540 |
+
val_loss_name = self.train_config.criterion.type
|
| 541 |
+
|
| 542 |
+
# print(all_outputs, all_labels)
|
| 543 |
+
metric_list = [
|
| 544 |
+
metric(all_labels.cpu(), all_outputs.detach().cpu(), **self.metrics[metric])
|
| 545 |
+
for metric in self.metrics
|
| 546 |
+
]
|
| 547 |
+
metric_name_list = [metric['type'] for metric in self._config.main_config.metrics]
|
| 548 |
+
return_dic = dict(
|
| 549 |
+
zip([val_loss_name,] + metric_name_list, [val_loss,] + metric_list,)
|
| 550 |
+
)
|
| 551 |
+
if log:
|
| 552 |
+
val_scores = self.log(
|
| 553 |
+
val_loss,
|
| 554 |
+
val_loss_name,
|
| 555 |
+
metric_list,
|
| 556 |
+
metric_name_list,
|
| 557 |
+
val_logger,
|
| 558 |
+
val_log_values,
|
| 559 |
+
global_step,
|
| 560 |
+
append_text,
|
| 561 |
+
)
|
| 562 |
+
return val_scores
|
| 563 |
+
return return_dic
|
src/utils/__init__.py
ADDED
|
File without changes
|
src/utils/__pycache__/__init__.cpython-38.pyc
ADDED
|
Binary file (164 Bytes). View file
|
|
|