import array
import pickle as pkl
from typing import Optional

import torch

from nemo.collections.nlp.data.data_utils.data_preprocessing import find_newlines, load_data_indices
from nemo.core.classes import Dataset
from nemo.utils import logging

__all__ = ['EntityLinkingDataset']


class EntityLinkingDataset(Dataset):
    """
    Parent class for entity linking encoder training and index
    datasets.

    Args:
        tokenizer (obj): huggingface tokenizer
        data_file (str): path to tab separated column file where data
            pairs appear in the format
            concept_ID\tconcept_synonym1\tconcept_synonym2\n
        newline_idx_file (str): path to pickle file containing the byte
            locations of the data_file newline characters
        max_seq_length (int): maximum length of a concept in tokens
        is_index_data (bool): whether the dataset will be used for building
            a nearest neighbors index
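
    Example:
        A single line of data_file might look like this (the concept ID
        and synonyms below are purely illustrative; IDs must be parseable
        as integers):

        9443\tcommon cold\thead cold\n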
| | """ |
| |
|
    def __init__(
        self,
        tokenizer: object,
        data_file: str,
        newline_idx_file: Optional[str] = None,
        max_seq_length: Optional[int] = 512,
        is_index_data: bool = False,
    ):

        self.tokenizer = tokenizer

        # Try to load an existing newline indices file for this data file
        newline_indices, newline_idx_file, _ = load_data_indices(newline_idx_file, data_file, "newline_indices")

        # If no indices file exists yet, find the newline byte offsets and cache them
        if newline_indices is None:
            logging.info("Getting datafile newline indices")

            with open(data_file, "rb") as f:
                contents = f.read()
                newline_indices = find_newlines(contents)
                newline_indices = array.array("I", newline_indices)

            # Store the data file indexing info for future runs
            with open(newline_idx_file, "wb") as f:
                pkl.dump(newline_indices, f)

        self.newline_indices = newline_indices
        self.data_file = data_file
        self.num_lines = len(newline_indices)
        self.max_seq_length = max_seq_length
        self.is_index_data = is_index_data

        logging.info(f"Loaded dataset with {self.num_lines} examples")

    def __len__(self):
        return self.num_lines

    def __getitem__(self, idx):
        # Byte offset of the line holding this example
        concept_offset = self.newline_indices[idx]

        # Seek straight to the example's line rather than reading the whole file
        with open(self.data_file, "r", encoding='utf-8-sig') as f:
            f.seek(concept_offset)
            concept = f.readline()[:-1]  # drop the trailing newline
            concept = concept.strip().split("\t")

        # Index data lines hold one synonym; training lines hold a synonym pair
        if self.is_index_data:
            concept_id, concept = concept
            return (int(concept_id), concept)

        else:
            concept_id, concept1, concept2 = concept
            return (int(concept_id), concept1, concept2)

    def _collate_fn(self, batch):
        """Collate a batch into input_ids, segment_ids, input_mask, and labels.

        Args:
            batch: A list of (concept_ID, concept_synonym1, concept_synonym2)
                tuples, or (concept_ID, concept_synonym) tuples when
                is_index_data is True.
        """
        if self.is_index_data:
            concept_ids, concepts = zip(*batch)
            concept_ids = list(concept_ids)
            concepts = list(concepts)

        else:
            # Flatten the synonym pairs into a single list and duplicate the
            # concept IDs so each synonym keeps its concept ID as a label
            concept_ids, concepts1, concepts2 = zip(*batch)
            concept_ids = list(concept_ids)
            concept_ids.extend(concept_ids)
            concepts = list(concepts1)
            concepts.extend(concepts2)

        # Tokenize the whole batch at once with padding and truncation
        batch = self.tokenizer(
            concepts,
            add_special_tokens=True,
            padding=True,
            truncation=True,
            max_length=self.max_seq_length,
            return_token_type_ids=True,
            return_attention_mask=True,
            return_length=True,
        )

        return (
            torch.LongTensor(batch["input_ids"]),
            torch.LongTensor(batch["token_type_ids"]),
            torch.LongTensor(batch["attention_mask"]),
            torch.LongTensor(concept_ids),
        )
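

# The sketch below is illustrative and not part of the original module: a
# minimal example of wiring the dataset into a PyTorch DataLoader. The
# tokenizer name and data file path are hypothetical placeholders.
if __name__ == "__main__":
    from torch.utils.data import DataLoader
    from transformers import AutoTokenizer

    tokenizer = AutoTokenizer.from_pretrained("bert-base-uncased")
    dataset = EntityLinkingDataset(
        tokenizer=tokenizer,
        data_file="train_pairs.tsv",  # hypothetical tab separated file
        max_seq_length=128,
    )

    # _collate_fn tokenizes each batch of synonyms and returns
    # (input_ids, token_type_ids, attention_mask, concept_ids) tensors
    loader = DataLoader(dataset, batch_size=32, collate_fn=dataset._collate_fn)

    input_ids, token_type_ids, attention_mask, concept_ids = next(iter(loader))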