File size: 5,206 Bytes
19b8775 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 |
import json
import random
import torch
class DataLoader:
    """
    Class for loading language id data and providing batches
    Attempt to recreate data pre-processing from: https://github.com/AU-DIS/LSTM_langid
    Uses methods from: https://github.com/AU-DIS/LSTM_langid/blob/main/src/language_datasets.py
    Data format is same as LSTM_langid
    """

    def __init__(self, device=None):
        # all of these are populated by load_data()
        self.batches = None        # list of {"sentences": LongTensor, "targets": LongTensor} dicts
        self.batches_iter = None   # iterator over self.batches, consumed by next()
        self.tag_to_idx = None     # label string -> integer index
        self.idx_to_tag = None     # integer index -> label string (inverse of tag_to_idx)
        self.lang_weights = None   # per-label weights for weighted cross entropy loss
        self.device = device       # torch device for created tensors (None = default/CPU)

    def load_data(self, batch_size, data_files, char_index, tag_index, randomize=False, randomize_range=(5,20),
                  max_length=None):
        """
        Load sequence data and labels, calculate weights for weighted cross entropy loss.
        Data is stored in a file, 1 example per line
        Example: {"text": "Hello world.", "label": "en"}

        batch_size: maximum number of examples per batch
        data_files: iterable of paths to JSON-lines data files
        char_index: char -> int mapping; chars not present map to char_index["UNK"]
        tag_index: initial label -> int mapping; copied, then extended with any new labels found
        randomize: if True, split each text into random-length pieces via randomize_data()
        randomize_range: (lower, upper) length limits used when randomize is True
        max_length: if not None, truncate each text to this many characters
        """
        # set up examples from data files
        examples = []
        for data_file in data_files:
            # context manager so the file handle is closed promptly; explicit
            # encoding so behavior does not depend on the platform default
            with open(data_file, encoding="utf-8") as fin:
                examples += [x for x in fin.read().split("\n") if x.strip()]
        random.shuffle(examples)
        examples = [json.loads(x) for x in examples]
        # add additional labels in this data set to tag index
        # (work on a copy so the caller's dict is not mutated)
        tag_index = dict(tag_index)
        new_labels = {x["label"] for x in examples} - set(tag_index.keys())
        for new_label in new_labels:
            tag_index[new_label] = len(tag_index)
        self.tag_to_idx = tag_index
        self.idx_to_tag = [i[1] for i in sorted([(v, k) for k, v in self.tag_to_idx.items()])]
        # set up lang counts used for weights for cross entropy loss
        lang_counts = [0 for _ in tag_index]
        # optionally limit text to max length
        if max_length is not None:
            examples = [{"text": x["text"][:max_length], "label": x["label"]} for x in examples]
        # randomize data: replace each example with random-length sub-sequences
        if randomize:
            split_examples = []
            for example in examples:
                sequence = example["text"]
                label = example["label"]
                sequences = DataLoader.randomize_data([sequence], upper_lim=randomize_range[1],
                                                      lower_lim=randomize_range[0])
                split_examples += [{"text": seq, "label": label} for seq in sequences]
            examples = split_examples
            random.shuffle(examples)
        # break into equal length batches: group by sequence length so each
        # batch can be stacked into a rectangular tensor without padding
        batch_lengths = {}
        for example in examples:
            sequence = example["text"]
            label = example["label"]
            if len(sequence) not in batch_lengths:
                batch_lengths[len(sequence)] = []
            sequence_as_list = [char_index.get(c, char_index["UNK"]) for c in sequence]
            batch_lengths[len(sequence)].append((sequence_as_list, tag_index[label]))
            lang_counts[tag_index[label]] += 1
        for length in batch_lengths:
            random.shuffle(batch_lengths[length])
        # create final set of batches
        batches = []
        for length in batch_lengths:
            for sublist in [batch_lengths[length][i:i + batch_size] for i in
                            range(0, len(batch_lengths[length]), batch_size)]:
                batches.append(sublist)
        self.batches = [self.build_batch_tensors(batch) for batch in batches]
        # set up lang weights; default=0 guards against an empty tag index
        most_frequent = max(lang_counts, default=0)
        # set to 0.0 if lang_count is 0 or most_frequent/lang_count otherwise
        lang_counts = [(most_frequent * x) / (max(1, x) ** 2) for x in lang_counts]
        self.lang_weights = torch.tensor(lang_counts, device=self.device, dtype=torch.float)
        # shuffle batches to mix up lengths
        random.shuffle(self.batches)
        self.batches_iter = iter(self.batches)

    @staticmethod
    def randomize_data(sentences, upper_lim=20, lower_lim=5):
        """
        Takes the original data and creates random length examples with length between upper limit and lower limit
        From LSTM_langid: https://github.com/AU-DIS/LSTM_langid/blob/main/src/language_datasets.py
        """
        new_data = []
        for sentence in sentences:
            remaining = sentence
            while lower_lim < len(remaining):
                lim = random.randint(lower_lim, upper_lim)
                m = min(len(remaining), lim)
                new_sentence = remaining[:m]
                new_data.append(new_sentence)
                # drop the (possibly partial) word left at the cut point and
                # continue from the next whitespace-delimited word
                split = remaining[m:].split(" ", 1)
                if len(split) <= 1:
                    break
                remaining = split[1]
        random.shuffle(new_data)
        return new_data

    def build_batch_tensors(self, batch):
        """
        Helper to turn batches into tensors

        batch: list of (char_id_list, label_id) pairs, all sequences the same length
        Returns a dict with "sentences" (batch x seq_len) and "targets" (batch) long tensors.
        """
        batch_tensors = dict()
        batch_tensors["sentences"] = torch.tensor([s[0] for s in batch], device=self.device, dtype=torch.long)
        batch_tensors["targets"] = torch.tensor([s[1] for s in batch], device=self.device, dtype=torch.long)
        return batch_tensors

    def next(self):
        # advances the batch iterator; raises StopIteration when exhausted
        return next(self.batches_iter)
|