File size: 1,098 Bytes
fd0332a
7f3b4e4
 
fd0332a
52a257c
fd0332a
 
 
 
 
 
 
 
 
 
52a257c
 
 
 
1907275
 
52a257c
 
 
 
 
 
 
 
 
 
 
 
6a41ff7
52a257c
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
from datasets import load_dataset
import numpy as np
import torch

from util import Config, GetDevice


class Dataset:
    """Text corpus loaded via ``load_dataset`` plus a helper to batch token ids.

    All attributes of the supplied config are copied onto the instance; this
    class reads ``remote_path``, ``seq_length`` and ``batch_size`` from them.
    """

    def __init__(self, config: Config):
        """Load the dataset and build an ASCII-only training text.

        Args:
            config: Configuration object whose attributes are copied onto this
                instance. ``config.remote_path`` is passed to ``load_dataset``.
        """
        # Shallow-copy the config's attribute dict so attributes later set on
        # this instance are not written back into ``config``.
        self.__dict__ = dict(config.__dict__)

        self.dataset = load_dataset(self.remote_path)
        # Concatenate every 'text' entry of the 'train' split with no
        # separator, then drop any characters that are not ASCII.
        self.text = (
            ''.join(self.dataset['train']['text'])
            .encode('ascii', 'ignore')
            .decode('ascii')
        )

    def batch(self, ids):
        """Split a sequence of token ids into fixed-size ``torch.long`` batches.

        Trailing ids that do not fill a complete batch of
        ``batch_size * seq_length`` elements are discarded.

        Args:
            ids: Sequence of integer token ids (anything ``np.asarray``
                accepts). Presumably 1-D; TODO confirm with callers.

        Returns:
            A tuple ``(batches, num_batches)``. ``batches`` is a list of
            ``num_batches`` ``torch.long`` tensors, each of shape
            ``(batch_size, seq_length)`` and placed on ``GetDevice()``.

        Raises:
            ValueError: If ``seq_length`` or ``batch_size`` is not positive.
        """
        if self.seq_length <= 0 or self.batch_size <= 0:
            # Otherwise the division below fails with ZeroDivisionError or
            # produces a nonsensical (negative) batch count.
            raise ValueError(
                f"seq_length and batch_size must be positive, got "
                f"seq_length={self.seq_length}, batch_size={self.batch_size}"
            )

        ids = np.asarray(ids)

        batch_elements = self.seq_length * self.batch_size
        num_batches = len(ids) // batch_elements
        total_elements = num_batches * batch_elements

        # Drop the incomplete tail so the reshape is exact.
        array_reshaped = ids[:total_elements].reshape(
            (num_batches, self.batch_size, self.seq_length)
        )

        # Resolve the device once instead of once per batch, and allocate each
        # tensor directly on it rather than building on CPU and copying.
        device = GetDevice()
        batches = [
            torch.tensor(chunk, dtype=torch.long, device=device)
            for chunk in array_reshaped
        ]

        return batches, num_batches