from collections import Counter
import json
import logging
import os
import random
from typing import List

import torch

from stanza.models.lemma_classifier.constants import DEFAULT_BATCH_SIZE

logger = logging.getLogger('stanza.lemmaclassifier')


class Dataset:
    def __init__(self, data_path: str, batch_size: int = DEFAULT_BATCH_SIZE, get_counts: bool = False, label_decoder: dict = None, shuffle: bool = True):
        """
        Loads a data file into batches of tokenized text sentences, target token indices, and true labels for each sentence.

        Args:
            data_path (str): Path to the data file: a JSON document containing tokenized text sentences, the target token index, and the true lemma label for each example.
            batch_size (int): Size of each batch of examples
            get_counts (optional, bool): Whether to build a map of label index to counts
            label_decoder (optional, dict): An existing mapping from label to label ID; a copy is made so the caller's dict is unchanged
            shuffle (optional, bool): Whether to shuffle the examples each time the dataset is iterated

        Iterating over the dataset yields batches of:
            1. List[List[str]]: Sentences, where each token is a separate entry in each sentence
            2. torch.Tensor: The index of the target token in each sentence
            3. List[List[int]]: UPOS IDs for each sentence (this is a List of Lists, not a tensor; it should be padded later)
            4. torch.Tensor: The label ID for each target token's lemma

        Also stored on the dataset:
            5. counts (Counter, only populated if get_counts): A mapping of label ID to counts in the dataset
            6. label_decoder (Mapping[str, int]): A map between the labels and their IDs
            7. upos_to_id (Mapping[str, int]): A map between the UPOS tags and their corresponding IDs found in the UPOS batches
        """
        if data_path is None or not os.path.exists(data_path):
            raise FileNotFoundError(f"Data file {data_path} could not be found.")

        if label_decoder is None:
            label_decoder = {}
        else:
            # copy the decoder so that adding labels from this file does not
            # modify the dict the caller passed in
            label_decoder = dict(label_decoder)
        logger.debug("Final label decoder: %s  Should be strings to ints", label_decoder)

        target_words = set()
        known_words = set()

        with open(data_path, "r", encoding="utf-8") as fin:
            sentences, indices, labels, upos_ids, counts, upos_to_id = [], [], [], [], Counter(), {}

            input_json = json.load(fin)
            sentences_data = input_json['sentences']
            self.target_upos = input_json['upos']

            for idx, sentence in enumerate(sentences_data):
                words, target_idx, upos_tags, label = sentence.get("words"), sentence.get("index"), sentence.get("upos_tags"), sentence.get("lemma")
                if None in [words, target_idx, upos_tags, label]:
                    raise ValueError(f"Expected data to be complete but found a null value in sentence {idx}: {sentence}")

                # assign a new label ID the first time each label is seen
                if label not in label_decoder:
                    label_decoder[label] = len(label_decoder)

                # convert UPOS tags to integer IDs, assigning new IDs on first sight
                converted_upos_tags = []
                for upos_tag in upos_tags:
                    if upos_tag not in upos_to_id:
                        upos_to_id[upos_tag] = len(upos_to_id)
                    converted_upos_tags.append(upos_to_id[upos_tag])

                sentences.append(words)
                indices.append(target_idx)
                upos_ids.append(converted_upos_tags)
                labels.append(label_decoder[label])

                if get_counts:
                    counts[label_decoder[label]] += 1

                target_words.add(words[target_idx])
                known_words.update(words)

        self.sentences = sentences
        self.indices = indices
        self.upos_ids = upos_ids
        self.labels = labels

        self.counts = counts
        self.label_decoder = label_decoder
        self.upos_to_id = upos_to_id

        self.batch_size = batch_size
        self.shuffle = shuffle

        self.known_words = [x.lower() for x in sorted(known_words)]
        self.target_words = set(x.lower() for x in target_words)
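
    # Expected input layout, reconstructed from the fields read above (the
    # values themselves are hypothetical): a JSON file with a top-level "upos"
    # list and a "sentences" list of records such as
    #
    #   {"words": ["the", "cities", "are", "large"],
    #    "index": 1,
    #    "upos_tags": ["DET", "NOUN", "AUX", "ADJ"],
    #    "lemma": "city"}
    #
    # where "index" marks the target token and "lemma" is its gold label.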

    def __len__(self):
        """
        Number of batches, rounded up to include a final partial batch
        """
        return len(self.sentences) // self.batch_size + (len(self.sentences) % self.batch_size > 0)
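
    # e.g. 10 sentences with batch_size 4 -> 10 // 4 + (10 % 4 > 0) == 2 + 1 == 3 batches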

    def __iter__(self):
        num_sentences = len(self.sentences)
        indices = list(range(num_sentences))
        if self.shuffle:
            random.shuffle(indices)
        for i in range(len(self)):
            batch_start = self.batch_size * i
            batch_end = min(batch_start + self.batch_size, num_sentences)

            batch_sentences = [self.sentences[x] for x in indices[batch_start:batch_end]]
            batch_indices = torch.tensor([self.indices[x] for x in indices[batch_start:batch_end]])
            batch_upos_ids = [self.upos_ids[x] for x in indices[batch_start:batch_end]]
            batch_labels = torch.tensor([self.labels[x] for x in indices[batch_start:batch_end]])
            yield batch_sentences, batch_indices, batch_upos_ids, batch_labels
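
# Usage sketch (the file path is hypothetical): the dataset is iterable, so a
# training loop can consume batches directly.
#
#   dataset = Dataset("train_data.json", batch_size=16, get_counts=True)
#   for sentences, indices, upos_ids, labels in dataset:
#       ...  # sentences: List[List[str]], indices/labels: torch.Tensor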


def extract_unknown_token_indices(tokenized_indices: torch.Tensor, unknown_token_idx: int) -> List[int]:
    """
    Extracts the indices within `tokenized_indices` which match `unknown_token_idx`

    Args:
        tokenized_indices (torch.Tensor): A tensor of token indices for words that have been mapped to vector indices.
        unknown_token_idx (int): The special index with which unknown tokens are marked in the word vectors.

    Returns:
        List[int]: A list of positions in `tokenized_indices` which match `unknown_token_idx`
    """
    return [idx for idx, token_index in enumerate(tokenized_indices) if token_index == unknown_token_idx]
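
# For example, with hypothetical values where 0 is the unknown index:
#   extract_unknown_token_indices(torch.tensor([5, 0, 7, 0]), 0) -> [1, 3]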


def get_device():
    """
    Get the device to run computations on
    """
    # prefer CUDA, then Apple MPS, falling back to CPU
    if torch.cuda.is_available():
        device = torch.device("cuda")
    elif torch.backends.mps.is_available():
        device = torch.device("mps")
    else:
        device = torch.device("cpu")
    return device
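
# Usage sketch (model and inputs are hypothetical names):
#
#   device = get_device()
#   model = model.to(device)
#   inputs = inputs.to(device)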


def round_up_to_multiple(number, multiple):
    """
    Rounds `number` up to the nearest multiple of `multiple`
    """
    if multiple == 0:
        raise ValueError("Cannot round up to a multiple of 0.")

    remainder = number % multiple
    if remainder != 0:
        rounded_number = number + (multiple - remainder)
    else:
        rounded_number = number

    return rounded_number
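
# e.g. round_up_to_multiple(10, 8) == 16 and round_up_to_multiple(16, 8) == 16;
# handy when padding variable sequence lengths to a fixed block size.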


def main():
    # smoke test: load the default dev set and build label counts
    default_test_path = os.path.join(os.path.dirname(__file__), "test_sets", "processed_ud_en", "combined_dev.txt")
    dataset = Dataset(default_test_path, get_counts=True)
    logger.info("Loaded %d examples in %d batches", len(dataset.sentences), len(dataset))


if __name__ == "__main__":
    main()