flpelerin committed on
Commit
6e1aaa9
·
1 Parent(s): 6d999e4

Update 3 files

Browse files

- /trainer.cli.py
- /trainer.py
- /model.py

Files changed (3) hide show
  1. model.py +1 -0
  2. trainer.cli.py +2 -0
  3. trainer.py +2 -1
model.py CHANGED
@@ -55,6 +55,7 @@ class Model:
55
 
56
  logits = output[0].tolist()
57
  text = tokenizer.decode(logits)
 
58
  return text
59
 
60
 
 
55
 
56
  logits = output[0].tolist()
57
  text = tokenizer.decode(logits)
58
+
59
  return text
60
 
61
 
trainer.cli.py CHANGED
@@ -33,6 +33,8 @@ if __name__ == '__main__':
33
  tokenizer.train(dataset.text, max_length=config.tokenizer.max_length)
34
 
35
  ids = tokenizer.c_encode(dataset.text)
 
 
36
  config.model.params.vocab_size = tokenizer.vocab_size
37
 
38
 
 
33
  tokenizer.train(dataset.text, max_length=config.tokenizer.max_length)
34
 
35
  ids = tokenizer.c_encode(dataset.text)
36
+
37
+ config.model.tokenizer = tokenizer
38
  config.model.params.vocab_size = tokenizer.vocab_size
39
 
40
 
trainer.py CHANGED
@@ -12,6 +12,8 @@ class Trainer:
12
  def log(self, loss: float):
13
  print(f"Epoch: {self.epoch} / {self.num_epochs}\t\tBatch: {self.batch} / {self.num_batches}\t\tLoss: {round(loss, 4)}")
14
 
 
 
15
 
16
 
17
  def train(self, batches):
@@ -29,4 +31,3 @@ class Trainer:
29
  self.optimizer.step()
30
 
31
  self.log(loss.item())
32
- #Train.LogStep(infer_config, log_config, epoch, num_epochs, batch, num_batches, loss)
 
12
  def log(self, loss: float):
13
  print(f"Epoch: {self.epoch} / {self.num_epochs}\t\tBatch: {self.batch} / {self.num_batches}\t\tLoss: {round(loss, 4)}")
14
 
15
+ if batch % 20 == 0:
16
+ print(f'{model.generate_text(self.model.tokenizer, self.config.inference.seed_text, self.config.inference.n_predict)}')
17
 
18
 
19
  def train(self, batches):
 
31
  self.optimizer.step()
32
 
33
  self.log(loss.item())