|
|
import json |
|
|
import random |
|
|
import torch |
|
|
|
|
|
|
|
|
class DataLoader:
    """
    Class for loading language id data and providing batches

    Attempt to recreate data pre-processing from: https://github.com/AU-DIS/LSTM_langid

    Uses methods from: https://github.com/AU-DIS/LSTM_langid/blob/main/src/language_datasets.py

    Data format is same as LSTM_langid
    """

    def __init__(self, device=None):
        # List of batch dicts ({"sentences": tensor, "targets": tensor}); filled by load_data
        self.batches = None
        # Iterator over self.batches, consumed by next()
        self.batches_iter = None
        # Mapping: label string -> integer index
        self.tag_to_idx = None
        # Inverse mapping: integer index -> label string
        self.idx_to_tag = None
        # Per-class weights for weighted cross entropy loss (float tensor)
        self.lang_weights = None
        # torch device on which all tensors built by this loader are placed
        self.device = device

    def load_data(self, batch_size, data_files, char_index, tag_index, randomize=False, randomize_range=(5,20),
                  max_length=None):
        """
        Load sequence data and labels, calculate weights for weighted cross entropy loss.
        Data is stored in a file, 1 example per line
        Example: {"text": "Hello world.", "label": "en"}

        batch_size: maximum number of examples per batch
        data_files: iterable of paths to JSONL files
        char_index: dict mapping characters to integer ids; must contain an "UNK" entry
            used for characters not present in the index
        tag_index: dict mapping label strings to integer ids; labels found in the data
            but missing here are appended (the caller's dict is not mutated)
        randomize: if True, split each text into random-length pieces (see randomize_data)
        randomize_range: (lower, upper) length limits used when randomize is True
        max_length: if set, truncate every text to at most this many characters
        """
        # Read all examples, skipping blank lines. A context manager guarantees
        # each file handle is closed (the previous version leaked them), and an
        # explicit encoding keeps JSON parsing platform-independent.
        examples = []
        for data_file in data_files:
            with open(data_file, encoding="utf-8") as f:
                examples += [x for x in f.read().split("\n") if x.strip()]
        random.shuffle(examples)
        examples = [json.loads(x) for x in examples]

        # Extend the label index with labels present in the data but not in the
        # supplied tag_index. Copy first so the caller's dict is untouched.
        tag_index = dict(tag_index)
        new_labels = set([x["label"] for x in examples]) - set(tag_index.keys())
        for new_label in new_labels:
            tag_index[new_label] = len(tag_index)
        self.tag_to_idx = tag_index
        self.idx_to_tag = [i[1] for i in sorted([(v,k) for k,v in self.tag_to_idx.items()])]

        # Per-language example counts, used below to build the loss weights
        lang_counts = [0 for _ in tag_index]

        if max_length is not None:
            examples = [{"text": x["text"][:max_length], "label": x["label"]} for x in examples]

        # Optionally split each example into random-length pieces so the model
        # also sees short inputs
        if randomize:
            split_examples = []
            for example in examples:
                sequence = example["text"]
                label = example["label"]
                sequences = DataLoader.randomize_data([sequence], upper_lim=randomize_range[1],
                                                      lower_lim=randomize_range[0])
                split_examples += [{"text": seq, "label": label} for seq in sequences]
            examples = split_examples
            random.shuffle(examples)

        # Bucket examples by sequence length so every batch can be stacked into
        # a rectangular tensor without padding
        batch_lengths = {}
        for example in examples:
            sequence = example["text"]
            label = example["label"]
            if len(sequence) not in batch_lengths:
                batch_lengths[len(sequence)] = []
            sequence_as_list = [char_index.get(c, char_index["UNK"]) for c in list(sequence)]
            batch_lengths[len(sequence)].append((sequence_as_list, tag_index[label]))
            lang_counts[tag_index[label]] += 1
        for length in batch_lengths:
            random.shuffle(batch_lengths[length])

        # Chunk each length bucket into batches of at most batch_size examples
        batches = []
        for length in batch_lengths:
            for sublist in [batch_lengths[length][i:i + batch_size] for i in
                            range(0, len(batch_lengths[length]), batch_size)]:
                batches.append(sublist)

        self.batches = [self.build_batch_tensors(batch) for batch in batches]

        # Inverse-frequency class weights: most_frequent / count for labels that
        # occur, 0 for labels with no examples (max(1, x) guards the division)
        most_frequent = max(lang_counts)
        lang_counts = [(most_frequent * x)/(max(1, x) ** 2) for x in lang_counts]
        self.lang_weights = torch.tensor(lang_counts, device=self.device, dtype=torch.float)

        random.shuffle(self.batches)
        self.batches_iter = iter(self.batches)

    @staticmethod
    def randomize_data(sentences, upper_lim=20, lower_lim=5):
        """
        Takes the original data and creates random length examples with length between upper limit and lower limit
        From LSTM_langid: https://github.com/AU-DIS/LSTM_langid/blob/main/src/language_datasets.py
        """
        new_data = []
        for sentence in sentences:
            remaining = sentence
            # Repeatedly cut a random-length prefix off the sentence until what
            # remains is no longer than lower_lim
            while lower_lim < len(remaining):
                lim = random.randint(lower_lim, upper_lim)
                m = min(len(remaining), lim)
                new_sentence = remaining[:m]
                new_data.append(new_sentence)
                # Resume at the next word boundary so pieces do not start
                # mid-word; if no space remains, the tail is dropped
                split = remaining[m:].split(" ", 1)
                if len(split) <= 1:
                    break
                remaining = split[1]
        random.shuffle(new_data)
        return new_data

    def build_batch_tensors(self, batch):
        """
        Helper to turn batches into tensors

        batch: list of (sequence_as_ids, label_id) pairs; all sequences in one
            batch share the same length, so the stack is rectangular
        Returns a dict with "sentences" (batch x seq_len long tensor) and
        "targets" (batch long tensor), both on self.device
        """
        batch_tensors = dict()
        batch_tensors["sentences"] = torch.tensor([s[0] for s in batch], device=self.device, dtype=torch.long)
        batch_tensors["targets"] = torch.tensor([s[1] for s in batch], device=self.device, dtype=torch.long)

        return batch_tensors

    def next(self):
        # Return the next pre-built batch; raises StopIteration when exhausted
        return next(self.batches_iter)
|
|
|
|
|
|