# Copyright (c) 2021, NVIDIA CORPORATION. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import array
import pickle as pkl
from typing import Optional

import torch

from nemo.collections.nlp.data.data_utils.data_preprocessing import find_newlines, load_data_indices
from nemo.core.classes import Dataset
from nemo.utils import logging

__all__ = ['EntityLinkingDataset']


class EntityLinkingDataset(Dataset):
"""
Parent class for entity linking encoder training and index
datasets
Args:
tokenizer (obj): huggingface tokenizer,
data_file (str): path to tab separated column file where data
pairs apear in the format
concept_ID\tconcept_synonym1\tconcept_synonym2\n
newline_idx_file (str): path to pickle file containing location
of data_file newline characters
max_seq_length (int): maximum length of a concept in tokens
is_index_data (bool): Whether dataset will be used for building
a nearest neighbors index
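
    Example (an illustrative sketch; the tokenizer name and data file path
    below are placeholders, not assets provided by this class):

        from transformers import AutoTokenizer

        tokenizer = AutoTokenizer.from_pretrained("bert-base-uncased")
        dataset = EntityLinkingDataset(tokenizer, "data/train_pairs.tsv")
        concept_id, synonym1, synonym2 = dataset[0]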
"""

    def __init__(
        self,
        tokenizer: object,
        data_file: str,
        newline_idx_file: Optional[str] = None,
        max_seq_length: Optional[int] = 512,
        is_index_data: bool = False,
    ):
        self.tokenizer = tokenizer

        # Try to load the newline indices file if it already exists
        newline_indices, newline_idx_file, _ = load_data_indices(newline_idx_file, data_file, "newline_indices")

        # If the newline indices file doesn't exist, generate and store the indices
        if newline_indices is None:
            logging.info("Getting datafile newline indices")

            with open(data_file, "rb") as f:
                contents = f.read()
                newline_indices = find_newlines(contents)
                newline_indices = array.array("I", newline_indices)

            # Store the data file indices to avoid generating them again
            with open(newline_idx_file, "wb") as f:
                pkl.dump(newline_indices, f)

        self.newline_indices = newline_indices
        self.data_file = data_file
        self.num_lines = len(newline_indices)
        self.max_seq_length = max_seq_length
        self.is_index_data = is_index_data

        logging.info(f"Loaded dataset with {self.num_lines} examples")

    def __len__(self):
        return self.num_lines

    def __getitem__(self, idx):
        concept_offset = self.newline_indices[idx]

        with open(self.data_file, "r", encoding='utf-8-sig') as f:
            # Find the data pair within the datafile using its byte offset
            f.seek(concept_offset)
            concept = f.readline()[:-1]
            concept = concept.strip().split("\t")

        if self.is_index_data:
            concept_id, concept = concept
            return (int(concept_id), concept)

        else:
            concept_id, concept1, concept2 = concept
            return (int(concept_id), concept1, concept2)

    def _collate_fn(self, batch):
        """Collate a batch of concepts into tokenized tensors plus concept ID labels.

        Args:
            batch: A list of tuples of format (concept_ID, concept_synonym1, concept_synonym2),
                or (concept_ID, concept_synonym) when is_index_data is True.

        Returns:
            Tuple of LongTensors: (input_ids, token_type_ids, attention_mask, concept_ids).
        """
        if self.is_index_data:
            concept_ids, concepts = zip(*batch)
            concept_ids = list(concept_ids)
            concepts = list(concepts)

        else:
            concept_ids, concepts1, concepts2 = zip(*batch)
            concept_ids = list(concept_ids)
            concept_ids.extend(concept_ids)  # Need to double the label list to match each concept
            concepts = list(concepts1)
            concepts.extend(concepts2)

        batch = self.tokenizer(
            concepts,
            add_special_tokens=True,
            padding=True,
            truncation=True,
            max_length=self.max_seq_length,
            return_token_type_ids=True,
            return_attention_mask=True,
            return_length=True,
        )

        return (
            torch.LongTensor(batch["input_ids"]),
            torch.LongTensor(batch["token_type_ids"]),
            torch.LongTensor(batch["attention_mask"]),
            torch.LongTensor(concept_ids),
        )
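
# Illustrative usage sketch, assuming a HuggingFace tokenizer and a local
# "data/train_pairs.tsv" file in the concept_ID\tsynonym1\tsynonym2 format
# (both are placeholders, not artifacts shipped with this module):
#
#     from torch.utils.data import DataLoader
#     from transformers import AutoTokenizer
#
#     tokenizer = AutoTokenizer.from_pretrained("bert-base-uncased")
#     dataset = EntityLinkingDataset(tokenizer, "data/train_pairs.tsv", max_seq_length=128)
#     loader = DataLoader(dataset, batch_size=32, collate_fn=dataset._collate_fn)
#     input_ids, token_type_ids, attention_mask, concept_ids = next(iter(loader))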