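"""
Utilities for the lemma classifier: a Dataset class that loads JSON training
data into shuffled batches, plus small helpers for finding unknown-token
indices, selecting a torch device, and rounding up to a multiple.
"""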
from collections import Counter
import json
import logging
import os
import random
from typing import List
import torch
from stanza.models.lemma_classifier.constants import DEFAULT_BATCH_SIZE
logger = logging.getLogger('stanza.lemmaclassifier')
class Dataset:
    def __init__(self, data_path: str, batch_size: int = DEFAULT_BATCH_SIZE, get_counts: bool = False, label_decoder: dict = None, shuffle: bool = True):
        """
        Loads a data file into batches of tokenized text sentences, target token indices, and true labels for each sentence.

        Args:
            data_path (str): Path to the data file, containing tokenized text sentences, the target token index, and the true label for the token's lemma on each line.
            batch_size (int): Size of each batch of examples
            get_counts (optional, bool): Whether to build a map from label index to counts
            label_decoder (optional, dict): An existing map from label to label index, e.g. from a trained model
            shuffle (optional, bool): Whether to shuffle the sentence order on each iteration

        Iterating over the dataset yields batches of:
            1. List[List[str]]: Sentences, where each token is a separate entry in each sentence
            2. torch.Tensor[int]: Indices of the target token in each sentence
            3. List[List[int]]: UPOS IDs for the tokens of each sentence (this is a List of Lists, not a tensor; it should be padded later)
            4. torch.Tensor[int]: Labels for each target token's lemma

        Also builds:
            counts (Counter): A mapping of label ID to counts in the dataset (only filled in if get_counts is True)
            label_decoder (Mapping[str, int]): A map between the labels and their indices
            upos_to_id (Mapping[str, int]): A map between the UPOS tags and their corresponding IDs found in the UPOS batches
        """
if data_path is None or not os.path.exists(data_path):
raise FileNotFoundError(f"Data file {data_path} could not be found.")
if label_decoder is None:
label_decoder = {}
else:
# if labels in the test set aren't in the original model,
# the model will never predict those labels,
# but we can still use those labels in a confusion matrix
label_decoder = dict(label_decoder)
logger.debug("Final label decoder: %s Should be strings to ints", label_decoder)
# words which we are analyzing
target_words = set()
# all known words in the dataset, not just target words
known_words = set()
with open(data_path, "r+", encoding="utf-8") as fin:
sentences, indices, labels, upos_ids, counts, upos_to_id = [], [], [], [], Counter(), {}
input_json = json.load(fin)
sentences_data = input_json['sentences']
self.target_upos = input_json['upos']
for idx, sentence in enumerate(sentences_data):
                # TODO Could replace this with sentence.values(), but need to know if Stanza requires Python 3.7 or later for backward compatibility reasons
words, target_idx, upos_tags, label = sentence.get("words"), sentence.get("index"), sentence.get("upos_tags"), sentence.get("lemma")
if None in [words, target_idx, upos_tags, label]:
raise ValueError(f"Expected data to be complete but found a null value in sentence {idx}: {sentence}")
                if label not in label_decoder:
                    label_decoder[label] = len(label_decoder)  # create a new ID for the unknown label
converted_upos_tags = [] # convert upos tags to upos IDs
for upos_tag in upos_tags:
if upos_tag not in upos_to_id:
upos_to_id[upos_tag] = len(upos_to_id) # create a new ID for the unknown UPOS tag
converted_upos_tags.append(upos_to_id[upos_tag])
sentences.append(words)
indices.append(target_idx)
upos_ids.append(converted_upos_tags)
labels.append(label_decoder[label])
if get_counts:
counts[label_decoder[label]] += 1
target_words.add(words[target_idx])
known_words.update(words)
self.sentences = sentences
self.indices = indices
self.upos_ids = upos_ids
self.labels = labels
self.counts = counts
self.label_decoder = label_decoder
self.upos_to_id = upos_to_id
self.batch_size = batch_size
self.shuffle = shuffle
self.known_words = [x.lower() for x in sorted(known_words)]
self.target_words = set(x.lower() for x in target_words)
def __len__(self):
"""
        Number of batches, rounding up to include a final partial batch
        (e.g. 10 sentences with a batch_size of 4 yield 3 batches)
"""
return len(self.sentences) // self.batch_size + (len(self.sentences) % self.batch_size > 0)
def __iter__(self):
num_sentences = len(self.sentences)
indices = list(range(num_sentences))
if self.shuffle:
random.shuffle(indices)
        for i in range(len(self)):
batch_start = self.batch_size * i
batch_end = min(batch_start + self.batch_size, num_sentences)
batch_sentences = [self.sentences[x] for x in indices[batch_start:batch_end]]
batch_indices = torch.tensor([self.indices[x] for x in indices[batch_start:batch_end]])
batch_upos_ids = [self.upos_ids[x] for x in indices[batch_start:batch_end]]
batch_labels = torch.tensor([self.labels[x] for x in indices[batch_start:batch_end]])
yield batch_sentences, batch_indices, batch_upos_ids, batch_labels
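
# A minimal usage sketch (the filename and field values here are hypothetical,
# but the JSON layout matches what __init__ reads: an 'upos' list plus a
# 'sentences' list of {"words", "index", "upos_tags", "lemma"} entries):
#
#     {"upos": ["AUX", "VERB"],
#      "sentences": [{"words": ["That", "'s", "fine"], "index": 1,
#                     "upos_tags": ["PRON", "AUX", "ADJ"], "lemma": "be"}]}
#
#     dataset = Dataset("train.json", batch_size=16, get_counts=True)
#     for batch_sentences, batch_indices, batch_upos_ids, batch_labels in dataset:
#         ...  # batch_indices and batch_labels are tensors; batch_upos_ids still needs padding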
def extract_unknown_token_indices(tokenized_indices: torch.Tensor, unknown_token_idx: int) -> List[int]:
    """
    Extracts the indices within `tokenized_indices` which match `unknown_token_idx`

    Args:
        tokenized_indices (torch.Tensor): A tensor filled with tokenized indices of words that have been mapped to vector indices.
        unknown_token_idx (int): The special index with which unknown tokens are marked in the word vectors.

    Returns:
        List[int]: A list of indices in `tokenized_indices` which match `unknown_token_idx`
    """
    return [idx for idx, token_index in enumerate(tokenized_indices) if token_index == unknown_token_idx]
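
# For instance, assuming unknown words were mapped to index 0 during tokenization:
#     extract_unknown_token_indices(torch.tensor([4, 0, 7, 0]), 0)  # -> [1, 3]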
def get_device():
    """
    Get the device to run computations on: CUDA if available, then MPS, then CPU
    """
    if torch.cuda.is_available():
        device = torch.device("cuda")
    elif torch.backends.mps.is_available():
        device = torch.device("mps")
    else:
        device = torch.device("cpu")
    return device
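
# Typical usage is to pick the device once and move the model and batches onto it:
#     device = get_device()
#     model = model.to(device)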
def round_up_to_multiple(number, multiple):
    """
    Rounds `number` up to the nearest multiple of `multiple`
    """
    if multiple == 0:
        raise ValueError("The second argument (multiple) cannot be zero.")
    # Calculate the remainder when dividing the number by the multiple
    remainder = number % multiple
    # If the remainder is non-zero, round up to the next multiple
    if remainder != 0:
        return number + (multiple - remainder)
    return number  # already a multiple, no rounding needed
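
# Worked examples: round_up_to_multiple(10, 4) == 12 (10 % 4 == 2, so 2 more is added),
# while round_up_to_multiple(12, 4) == 12 since 12 is already a multiple of 4.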
def main():
    default_test_path = os.path.join(os.path.dirname(__file__), "test_sets", "processed_ud_en", "combined_dev.txt")  # get the GUM stuff
    dataset = Dataset(default_test_path, get_counts=True)
    logger.info("Loaded %d sentences with %d distinct labels", len(dataset.sentences), len(dataset.label_decoder))
if __name__ == "__main__":
main()