Update 3 files
Browse files
- /trainer.cli.py
- /trainer.py
- /model.py
- model.py +1 -0
- trainer.cli.py +2 -0
- trainer.py +2 -1
model.py
CHANGED
|
@@ -55,6 +55,7 @@ class Model:
|
|
| 55 |
|
| 56 |
logits = output[0].tolist()
|
| 57 |
text = tokenizer.decode(logits)
|
|
|
|
| 58 |
return text
|
| 59 |
|
| 60 |
|
|
|
|
| 55 |
|
| 56 |
logits = output[0].tolist()
|
| 57 |
text = tokenizer.decode(logits)
|
| 58 |
+
|
| 59 |
return text
|
| 60 |
|
| 61 |
|
trainer.cli.py
CHANGED
|
@@ -33,6 +33,8 @@ if __name__ == '__main__':
|
|
| 33 |
tokenizer.train(dataset.text, max_length=config.tokenizer.max_length)
|
| 34 |
|
| 35 |
ids = tokenizer.c_encode(dataset.text)
|
|
|
|
|
|
|
| 36 |
config.model.params.vocab_size = tokenizer.vocab_size
|
| 37 |
|
| 38 |
|
|
|
|
# Fit the tokenizer on the raw training corpus, then encode the whole corpus
# once up front (c_encode is presumably a C/fast encoding path — TODO confirm).
tokenizer.train(dataset.text, max_length=config.tokenizer.max_length)

ids = tokenizer.c_encode(dataset.text)

# Hand the fitted tokenizer to the model via its config so downstream code
# (e.g. sampling generations during training) can decode token ids, and sync
# the model's vocab size to whatever vocabulary the tokenizer actually learned.
config.model.tokenizer = tokenizer
config.model.params.vocab_size = tokenizer.vocab_size
|
| 39 |
|
| 40 |
|
trainer.py
CHANGED
|
@@ -12,6 +12,8 @@ class Trainer:
|
|
| 12 |
def log(self, loss: float):
|
| 13 |
print(f"Epoch: {self.epoch} / {self.num_epochs}\t\tBatch: {self.batch} / {self.num_batches}\t\tLoss: {round(loss, 4)}")
|
| 14 |
|
|
|
|
|
|
|
| 15 |
|
| 16 |
|
| 17 |
def train(self, batches):
|
|
@@ -29,4 +31,3 @@ class Trainer:
|
|
| 29 |
self.optimizer.step()
|
| 30 |
|
| 31 |
self.log(loss.item())
|
| 32 |
-
#Train.LogStep(infer_config, log_config, epoch, num_epochs, batch, num_batches, loss)
|
|
|
|
def log(self, loss: float):
    """Print epoch/batch progress with the current loss; every 20 batches,
    also print a sample generation so model quality can be eyeballed.

    Args:
        loss: the latest batch loss (rounded to 4 places for display).
    """
    print(f"Epoch: {self.epoch} / {self.num_epochs}\t\tBatch: {self.batch} / {self.num_batches}\t\tLoss: {round(loss, 4)}")
    # BUG FIX: the original read the bare name `batch`, which is not in scope
    # inside this method (NameError on first call) — the batch counter lives
    # on self, exactly as used in the progress line above.
    if self.batch % 20 == 0:
        # NOTE(review): assumes `model` is the imported model module and that
        # `self.model.tokenizer` is populated (trainer.cli.py assigns
        # config.model.tokenizer) — confirm both against the full file.
        print(f'{model.generate_text(self.model.tokenizer, self.config.inference.seed_text, self.config.inference.n_predict)}')
| 17 |
|
| 18 |
|
| 19 |
def train(self, batches):
|
|
|
|
| 31 |
self.optimizer.step()
|
| 32 |
|
| 33 |
self.log(loss.item())
|
|
|