Update 3 files
- dataset.py +21 -7
- trainer.cli.py +4 -2
- trainer.py +1 -0
dataset.py
CHANGED
@@ -1,6 +1,6 @@
 from datasets import load_dataset
 
-from util import Config
+from util import Config, GetDevice
 
 
 class Dataset:
@@ -11,11 +11,25 @@ class Dataset:
         self.text = ''.join(s for s in self.dataset['train']['text']).encode('ascii', 'ignore').decode('ascii')
 
 
-    def __iadd__(self, args):
-        name, value = args
-        setattr(self, name, value)
-        return self
+    #def __iadd__(self, args):
+    #    name, value = args
+    #    setattr(self, name, value)
+    #    return self
 
 
-    def batch(self,
-
+    def batch(self, ids):
+        if not isinstance(ids, np.ndarray):
+            ids = np.array(ids)
+
+        num_batches = len(ids) // (self.seq_length * self.batch_size)
+        total_elements = num_batches * self.seq_length * self.batch_size
+
+        trimmed_array = ids[:total_elements]
+        array_reshaped = trimmed_array.reshape((num_batches, self.batch_size, self.seq_length))
+
+        batches = []
+        for batch in array_reshaped:
+            tensor_batch = torch.tensor(batch, dtype=torch.long).to(GetDevice())
+            batches.append(tensor_batch)
+
+        return batches, num_batches
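Note that the new batch method references np, torch, and GetDevice, while this commit only adds GetDevice to the import line; unless numpy and torch are already imported elsewhere in dataset.py, the file would also need `import numpy as np` and `import torch`. Below is a minimal, self-contained sketch of the same trim-and-reshape batching, with hypothetical seq_length/batch_size values and a plain device pick standing in for GetDevice (whose definition in util is not shown in this commit):

import numpy as np
import torch

# Hypothetical hyperparameters; in the commit these come from
# self.seq_length and self.batch_size on the Dataset instance.
seq_length, batch_size = 8, 4
ids = np.arange(100)

# Keep only as many tokens as fill whole (batch_size x seq_length) blocks.
num_batches = len(ids) // (seq_length * batch_size)    # 100 // 32 = 3
trimmed = ids[:num_batches * seq_length * batch_size]  # first 96 tokens
reshaped = trimmed.reshape((num_batches, batch_size, seq_length))

# Stand-in for GetDevice(): pick CUDA when available, else CPU.
device = "cuda" if torch.cuda.is_available() else "cpu"
batches = [torch.tensor(b, dtype=torch.long).to(device) for b in reshaped]
assert batches[0].shape == (batch_size, seq_length)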
trainer.cli.py
CHANGED
@@ -29,11 +29,13 @@ if __name__ == '__main__':
 
     tokenizer = Tokenizer()
     tokenizer.train(dataset.text, max_length=config.tokenizer.max_length)
+
     ids = tokenizer.c_encode(dataset.text)
+    config.model.params.vocab_size = tokenizer.vocab_size
+
 
+    batches, num_batches = dataset.batch(ids)
 
-    dataset += ("ids", ids)
-    #dataset.batch(ids)
 
     print(f"dataset ids: {dataset.ids}")
 
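One side effect worth flagging: the removed `dataset += ("ids", ids)` line was what attached ids to the Dataset through the now-commented __iadd__, yet the `print(f"dataset ids: {dataset.ids}")` context line survives. Unless ids is attached somewhere not shown in this diff, that print now raises AttributeError. A minimal sketch of an explicit fix (an assumption, not part of the commit):

# __iadd__ is commented out in dataset.py, so `dataset += ("ids", ids)`
# would no longer work either; a plain attribute assignment suffices:
dataset.ids = ids
print(f"dataset ids: {dataset.ids}")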
trainer.py
CHANGED
@@ -9,6 +9,7 @@ class Trainer:
         self.__dict__ = dict(config.__dict__)
 
         #self.wandb = Wandb(config.wandb)
+
 
         self.model = Model(config.model)
 