import json
import random
import torch


class DataLoader:
    """
    Class for loading language id data and providing batches

    Attempt to recreate data pre-processing from: https://github.com/AU-DIS/LSTM_langid

    Uses methods from: https://github.com/AU-DIS/LSTM_langid/blob/main/src/language_datasets.py

    Data format is same as LSTM_langid
    """

    def __init__(self, device=None):
        self.batches = None        # list of {"sentences", "targets"} tensor dicts
        self.batches_iter = None   # iterator over self.batches
        self.tag_to_idx = None     # language label -> index
        self.idx_to_tag = None     # index -> language label
        self.lang_weights = None   # per-language weights for the cross entropy loss
        self.device = device

    def load_data(self, batch_size, data_files, char_index, tag_index, randomize=False, randomize_range=(5, 20),
                  max_length=None):
        """
        Load sequence data and labels, calculate weights for weighted cross entropy loss.
        Data is stored in a file, 1 example per line
        Example: {"text": "Hello world.", "label": "en"}
        """

        # set up examples from data files
        examples = []
        for data_file in data_files:
            with open(data_file) as fin:
                examples += [x for x in fin.read().split("\n") if x.strip()]
        random.shuffle(examples)
        examples = [json.loads(x) for x in examples]

        # add additional labels in this data set to tag index
        tag_index = dict(tag_index)
        new_labels = {x["label"] for x in examples} - set(tag_index)
        for new_label in new_labels:
            tag_index[new_label] = len(tag_index)
        self.tag_to_idx = tag_index
        # idx_to_tag is the inverse mapping, ordered by index
        self.idx_to_tag = [tag for _, tag in sorted((idx, tag) for tag, idx in self.tag_to_idx.items())]
        
        # per-language example counts, later converted into cross entropy loss weights
        lang_counts = [0 for _ in tag_index]

        # optionally limit text to max length
        if max_length is not None:
            examples = [{"text": x["text"][:max_length], "label": x["label"]} for x in examples]

        # randomize data
        if randomize:
            split_examples = []
            for example in examples:
                sequence = example["text"]
                label = example["label"]
                sequences = DataLoader.randomize_data([sequence], upper_lim=randomize_range[1], 
                                                      lower_lim=randomize_range[0])
                split_examples += [{"text": seq, "label": label} for seq in sequences]
            examples = split_examples
            random.shuffle(examples)

        # break into equal length batches
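        # grouping sequences by exact length lets each batch be stacked into a
        # single tensor without any padding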
        batch_lengths = {}
        for example in examples:
            sequence = example["text"]
            label = example["label"]
            if len(sequence) not in batch_lengths:
                batch_lengths[len(sequence)] = []
            sequence_as_list = [char_index.get(c, char_index["UNK"]) for c in sequence]
            batch_lengths[len(sequence)].append((sequence_as_list, tag_index[label]))
            lang_counts[tag_index[label]] += 1
        for length in batch_lengths:
            random.shuffle(batch_lengths[length])

        # create final set of batches
        batches = []
        for length in batch_lengths:
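            # chunk this length bucket into batches of batch_size; the final chunk may be smaller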
            for sublist in [batch_lengths[length][i:i + batch_size] for i in
                            range(0, len(batch_lengths[length]), batch_size)]:
                batches.append(sublist)

        self.batches = [self.build_batch_tensors(batch) for batch in batches]

        # set up lang weights
        most_frequent = max(lang_counts)
        # weight = most_frequent / count for seen languages, 0.0 for unseen ones:
        # (most_frequent * x) / max(1, x)**2 equals most_frequent / x when x > 0 and 0.0 when x == 0
        lang_counts = [(most_frequent * x) / (max(1, x) ** 2) for x in lang_counts]
        self.lang_weights = torch.tensor(lang_counts, device=self.device, dtype=torch.float)
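        # intended for use as e.g. torch.nn.CrossEntropyLoss(weight=self.lang_weights)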

        # shuffle batches to mix up lengths
        random.shuffle(self.batches)
        self.batches_iter = iter(self.batches)

    @staticmethod
    def randomize_data(sentences, upper_lim=20, lower_lim=5):
        """
        Takes the original data and creates random length examples with length between upper limit and lower limit
        From LSTM_langid: https://github.com/AU-DIS/LSTM_langid/blob/main/src/language_datasets.py
        """

        new_data = []
        for sentence in sentences:
            remaining = sentence
            while lower_lim < len(remaining):
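                # take a random-length prefix of the remaining text as a new example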
                lim = random.randint(lower_lim, upper_lim)
                m = min(len(remaining), lim)
                new_sentence = remaining[:m]
                new_data.append(new_sentence)
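                # skip ahead to the next space so the next example starts on a word boundary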
                split = remaining[m:].split(" ", 1)
                if len(split) <= 1:
                    break
                remaining = split[1]
        random.shuffle(new_data)
        return new_data

    def build_batch_tensors(self, batch):
        """
        Helper to turn batches into tensors
        """

        batch_tensors = dict()
        batch_tensors["sentences"] = torch.tensor([s[0] for s in batch], device=self.device, dtype=torch.long)
        batch_tensors["targets"] = torch.tensor([s[1] for s in batch], device=self.device, dtype=torch.long)

        return batch_tensors

    def next(self):
        """
        Return the next batch, raising StopIteration once the epoch is exhausted.
        """
        return next(self.batches_iter)
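

# ---------------------------------------------------------------------------
# A minimal usage sketch (an illustration, not part of the training pipeline):
# the data file name and the contents of char_index / tag_index below are
# hypothetical stand-ins for the vocabularies the surrounding code would build.
# char_index must contain an "UNK" entry, since load_data falls back to it.
# ---------------------------------------------------------------------------
if __name__ == "__main__":
    char_index = {"UNK": 0}
    for i, c in enumerate("abcdefghijklmnopqrstuvwxyz .", start=1):
        char_index[c] = i
    tag_index = {"en": 0, "de": 1}

    loader = DataLoader(device="cpu")
    loader.load_data(batch_size=32,
                     data_files=["langid_train.jsonl"],  # hypothetical file of {"text": ..., "label": ...} lines
                     char_index=char_index,
                     tag_index=tag_index,
                     randomize=True)
    # each batch is a dict of equal-length sentence tensors and their language targets
    batch = loader.next()
    print(batch["sentences"].shape, batch["targets"].shape)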