|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
""" |
|
|
Utility functions for Token Classification NLP tasks |
|
|
Some parts of this code were adapted from the HuggingFace library at |
|
|
https://github.com/huggingface/pytorch-pretrained-BERT |
|
|
""" |
|
|
|
|
|
import os |
|
|
import pickle |
|
|
import tempfile |
|
|
import time |
|
|
from typing import Dict, List, Optional |
|
|
|
|
|
import numpy as np |
|
|
|
|
|
from nemo.collections.common.tokenizers.tokenizer_spec import TokenizerSpec |
|
|
from nemo.collections.nlp.data.data_utils.data_preprocessing import get_stats |
|
|
from nemo.core.classes import Dataset |
|
|
from nemo.core.neural_types import ChannelType, LabelsType, MaskType, NeuralType |
|
|
from nemo.utils import logging |
|
|
from nemo.utils.get_rank import is_global_rank_zero |
|
|
|
|
|
__all__ = ['BertTokenClassificationDataset', 'BertTokenClassificationInferDataset'] |
|
|
|
|
|
|
|
|
def get_features(
    queries: List[str],
    tokenizer: TokenizerSpec,
    max_seq_length: int = -1,
    label_ids: Optional[Dict[str, int]] = None,
    pad_label: str = 'O',
    raw_labels: Optional[List[List[str]]] = None,
    ignore_extra_tokens: bool = False,
    ignore_start_end: bool = False,
):
    """
    Processes the data and returns features.

    Args:
        queries: text sequences
        tokenizer: such as AutoTokenizer
        max_seq_length: max sequence length minus 2 for [CLS] and [SEP], when -1 - use the max len from the data
        label_ids: dict to map labels to label ids.
            Starts with pad_label->0 and then increases in alphabetical order.
            Required for training and evaluation, not needed for inference.
        pad_label: pad value use for labels. By default, it's the neutral label.
        raw_labels: list of lists with labels for every word of every sequence
        ignore_extra_tokens: whether to ignore extra tokens in the loss_mask
        ignore_start_end: whether to ignore bos and eos tokens in the loss_mask

    Returns:
        Tuple (all_input_ids, all_segment_ids, all_input_mask, all_subtokens_mask,
        all_loss_mask, all_labels). all_labels is empty when raw_labels is None.
    """
    all_subtokens = []
    all_loss_mask = []
    all_subtokens_mask = []
    all_segment_ids = []
    all_input_ids = []
    all_input_mask = []
    sent_lengths = []
    all_labels = []
    with_label = raw_labels is not None
    # pad_id is loop-invariant: look it up once instead of per query
    pad_id = label_ids[pad_label] if with_label else None

    for i, query in enumerate(queries):
        words = query.strip().split()

        # add bos token
        subtokens = [tokenizer.cls_token]
        loss_mask = [1 - ignore_start_end]
        subtokens_mask = [0]
        if with_label:
            labels = [pad_id]
            query_labels = [label_ids[lab] for lab in raw_labels[i]]

        for j, word in enumerate(words):
            word_tokens = tokenizer.text_to_tokens(word)

            # some tokenizers return an empty list for a non-empty word (e.g. a bare
            # special character); substitute the unknown token to keep the word/label
            # alignment intact
            if len(word.strip()) > 0 and len(word_tokens) == 0:
                word_tokens = [tokenizer.ids_to_tokens(tokenizer.unk_id)]

            subtokens.extend(word_tokens)

            # the loss is computed on the first subtoken of each word; the remaining
            # subtokens are included only when ignore_extra_tokens is False
            loss_mask.append(1)
            loss_mask.extend([int(not ignore_extra_tokens)] * (len(word_tokens) - 1))

            subtokens_mask.append(1)
            subtokens_mask.extend([0] * (len(word_tokens) - 1))

            if with_label:
                labels.extend([query_labels[j]] * len(word_tokens))

        # add eos token
        subtokens.append(tokenizer.sep_token)
        loss_mask.append(1 - ignore_start_end)
        subtokens_mask.append(0)
        sent_lengths.append(len(subtokens))
        all_subtokens.append(subtokens)
        all_loss_mask.append(loss_mask)
        all_subtokens_mask.append(subtokens_mask)
        all_input_mask.append([1] * len(subtokens))

        if with_label:
            labels.append(pad_id)
            all_labels.append(labels)

    max_seq_length_data = max(sent_lengths)
    max_seq_length = min(max_seq_length, max_seq_length_data) if max_seq_length > 0 else max_seq_length_data
    logging.info(f'Setting Max Seq length to: {max_seq_length}')
    get_stats(sent_lengths)
    too_long_count = 0

    for i, subtokens in enumerate(all_subtokens):
        if len(subtokens) > max_seq_length:
            # truncate from the front, keeping the tail of the sequence and
            # re-attaching [CLS] (and its mask entries) at position 0
            subtokens = [tokenizer.cls_token] + subtokens[-max_seq_length + 1 :]
            all_input_mask[i] = [1] + all_input_mask[i][-max_seq_length + 1 :]
            all_loss_mask[i] = [int(not ignore_start_end)] + all_loss_mask[i][-max_seq_length + 1 :]
            all_subtokens_mask[i] = [0] + all_subtokens_mask[i][-max_seq_length + 1 :]

            if with_label:
                all_labels[i] = [pad_id] + all_labels[i][-max_seq_length + 1 :]
            too_long_count += 1

        all_input_ids.append(tokenizer.tokens_to_ids(subtokens))

        # pad short sequences on the right up to max_seq_length
        if len(subtokens) < max_seq_length:
            extra = max_seq_length - len(subtokens)
            all_input_ids[i] = all_input_ids[i] + [0] * extra
            all_loss_mask[i] = all_loss_mask[i] + [0] * extra
            all_subtokens_mask[i] = all_subtokens_mask[i] + [0] * extra
            all_input_mask[i] = all_input_mask[i] + [0] * extra

            if with_label:
                all_labels[i] = all_labels[i] + [pad_id] * extra

        all_segment_ids.append([0] * max_seq_length)

    logging.warning(f'{too_long_count} are longer than {max_seq_length}')

    # log the first example for sanity checking
    for i in range(min(len(all_input_ids), 1)):
        logging.info("*** Example ***")
        logging.info("i: %s", i)
        logging.info("subtokens: %s", " ".join(list(map(str, all_subtokens[i]))))
        logging.info("loss_mask: %s", " ".join(list(map(str, all_loss_mask[i]))))
        logging.info("input_mask: %s", " ".join(list(map(str, all_input_mask[i]))))
        logging.info("subtokens_mask: %s", " ".join(list(map(str, all_subtokens_mask[i]))))
        if with_label:
            logging.info("labels: %s", " ".join(list(map(str, all_labels[i]))))
    return (all_input_ids, all_segment_ids, all_input_mask, all_subtokens_mask, all_loss_mask, all_labels)
|
|
|
|
|
|
|
|
class BertTokenClassificationDataset(Dataset):
    """
    Creates dataset to use during training for token classification tasks with a pretrained model.

    Converts from raw data to an instance that can be used by Dataloader.
    For dataset to use during inference without labels, see BertTokenClassificationInferDataset.

    Args:
        text_file: file to sequences, each line should a sentence, no header.
        label_file: file to labels, each line corresponds to word labels for a sentence in the text_file. No header.
        max_seq_length: max sequence length minus 2 for [CLS] and [SEP]
        tokenizer: such as AutoTokenizer
        num_samples: number of samples you want to use for the dataset.
            If -1, use all dataset. Useful for testing.
        pad_label: pad value use for labels. By default, it's the neutral label.
        label_ids: label_ids (dict): dict to map labels to label ids.
            Starts with pad_label->0 and then increases in alphabetical order
            For dev set use label_ids generated during training to support
            cases when not all labels are present in the dev set.
            For training set label_ids should be None.
        ignore_extra_tokens: whether to ignore extra tokens in the loss_mask
        ignore_start_end: whether to ignore bos and eos tokens in the loss_mask
        use_cache: whether to use processed data cache or not
    """

    @property
    def output_types(self) -> Optional[Dict[str, NeuralType]]:
        """Returns definitions of module output ports.
        """
        return {
            'input_ids': NeuralType(('B', 'T'), ChannelType()),
            'segment_ids': NeuralType(('B', 'T'), ChannelType()),
            'input_mask': NeuralType(('B', 'T'), MaskType()),
            'subtokens_mask': NeuralType(('B', 'T'), MaskType()),
            'loss_mask': NeuralType(('B', 'T'), MaskType()),
            'labels': NeuralType(('B', 'T'), LabelsType()),
        }

    def __init__(
        self,
        text_file: str,
        label_file: str,
        max_seq_length: int,
        tokenizer: TokenizerSpec,
        num_samples: int = -1,
        pad_label: str = 'O',
        label_ids: Optional[Dict[str, int]] = None,
        ignore_extra_tokens: bool = False,
        ignore_start_end: bool = False,
        use_cache: bool = True,
    ):
        """ Initializes BertTokenClassificationDataset. """

        data_dir = os.path.dirname(text_file)
        text_filename = os.path.basename(text_file)
        lbl_filename = os.path.basename(label_file)

        if not text_filename.endswith('.txt'):
            # f-string prefix was missing here: the message never showed the file name
            raise ValueError(f"{text_file} should have extension .txt")

        # cache key encodes everything that affects the computed features
        vocab_size = getattr(tokenizer, "vocab_size", 0)
        features_pkl = os.path.join(
            data_dir,
            f"cached__{text_filename}__{lbl_filename}__{tokenizer.name}_{max_seq_length}_{vocab_size}_{num_samples}",
        )

        # only rank 0 computes (and caches) features; other ranks wait and load
        master_device = is_global_rank_zero()
        features = None
        if master_device and (not use_cache or not os.path.exists(features_pkl)):
            if num_samples == 0:
                raise ValueError("num_samples has to be positive", num_samples)

            with open(text_file, 'r') as f:
                text_lines = f.readlines()

            labels_lines = []
            with open(label_file, 'r') as f:
                for line in f:
                    line = line.strip().split()
                    labels_lines.append(line)

            if len(labels_lines) != len(text_lines):
                raise ValueError("Labels file should contain labels for every word")

            if num_samples > 0:
                # truncate text and labels together so they stay aligned
                dataset = list(zip(text_lines, labels_lines))
                dataset = dataset[:num_samples]

                dataset = list(zip(*dataset))
                text_lines = dataset[0]
                labels_lines = dataset[1]

            features = get_features(
                queries=text_lines,
                max_seq_length=max_seq_length,
                tokenizer=tokenizer,
                pad_label=pad_label,
                raw_labels=labels_lines,
                label_ids=label_ids,
                ignore_extra_tokens=ignore_extra_tokens,
                ignore_start_end=ignore_start_end,
            )

            # write to a temp file then rename, so other ranks never observe a
            # partially written cache file
            ofd, tmp_features_pkl = tempfile.mkstemp(
                suffix='.pkl', prefix=os.path.basename(features_pkl), dir=os.path.dirname(features_pkl)
            )
            with os.fdopen(ofd, 'wb') as temp_f:
                pickle.dump(features, temp_f)

            os.rename(tmp_features_pkl, features_pkl)
            logging.info(f'features saved to {features_pkl}')

        # non-master ranks poll until the cache file shows up
        if not master_device:
            while features is None and not os.path.exists(features_pkl):
                time.sleep(10)

        if features is None:
            # use a context manager so the file handle is closed deterministically
            with open(features_pkl, 'rb') as f:
                features = pickle.load(f)
            logging.info(f'features restored from {features_pkl}')

        self.all_input_ids = features[0]
        self.all_segment_ids = features[1]
        self.all_input_mask = features[2]
        self.all_subtokens_mask = features[3]
        self.all_loss_mask = features[4]
        self.all_labels = features[5]

    def __len__(self):
        return len(self.all_input_ids)

    def __getitem__(self, idx):
        return (
            np.array(self.all_input_ids[idx]),
            np.array(self.all_segment_ids[idx]),
            # np.long was removed in NumPy 1.24; it was an alias for the builtin int
            np.array(self.all_input_mask[idx], dtype=int),
            np.array(self.all_subtokens_mask[idx]),
            np.array(self.all_loss_mask[idx]),
            np.array(self.all_labels[idx]),
        )
|
|
|
|
|
|
|
|
class BertTokenClassificationInferDataset(Dataset):
    """
    Creates dataset to use during inference for token classification tasks with a pretrained model.
    For dataset to use during training with labels, see BertTokenClassificationDataset.
    """

    @property
    def output_types(self) -> Optional[Dict[str, NeuralType]]:
        """Returns definitions of module output ports.
        """
        return {
            'input_ids': NeuralType(('B', 'T'), ChannelType()),
            'segment_ids': NeuralType(('B', 'T'), ChannelType()),
            'input_mask': NeuralType(('B', 'T'), MaskType()),
            'subtokens_mask': NeuralType(('B', 'T'), MaskType()),
        }

    def __init__(
        self, queries: List[str], max_seq_length: int, tokenizer: TokenizerSpec,
    ):
        """
        Initializes BertTokenClassificationInferDataset
        Args:
            queries: text sequences
            max_seq_length: max sequence length minus 2 for [CLS] and [SEP]
            tokenizer: such as AutoTokenizer
        """
        # no labels at inference time, so get_features returns an empty labels list
        features = get_features(queries=queries, max_seq_length=max_seq_length, tokenizer=tokenizer)

        self.all_input_ids = features[0]
        self.all_segment_ids = features[1]
        self.all_input_mask = features[2]
        self.all_subtokens_mask = features[3]

    def __len__(self):
        return len(self.all_input_ids)

    def __getitem__(self, idx):
        return (
            np.array(self.all_input_ids[idx]),
            np.array(self.all_segment_ids[idx]),
            # np.long was removed in NumPy 1.24; it was an alias for the builtin int
            np.array(self.all_input_mask[idx], dtype=int),
            np.array(self.all_subtokens_mask[idx]),
        )
|
|
|