AlexSychovUN commited on
Commit
62dcdc0
·
1 Parent(s): 2b97908

Added files

Browse files
transformer_from_scratch/config.py CHANGED
@@ -21,7 +21,7 @@ def get_weights_file_path(config, epoch):
21
  model_folder = config["model_folder"]
22
  model_basename = config["model_basename"]
23
  model_filename = f"{model_basename}{epoch}.pt"
24
- return str(Path(".") / model_folder / model_basename / model_filename)
25
 
26
 
27
  def latest_weights_file_path(config):
 
21
  model_folder = config["model_folder"]
22
  model_basename = config["model_basename"]
23
  model_filename = f"{model_basename}{epoch}.pt"
24
+ return str(Path('.') / model_folder / model_filename)
25
 
26
 
27
  def latest_weights_file_path(config):
transformer_from_scratch/train.py CHANGED
@@ -3,6 +3,7 @@ from pathlib import Path
3
  import torch
4
  import torch.nn as nn
5
  import torchmetrics
 
6
 
7
  from datasets import load_dataset
8
  from tokenizers import Tokenizer
@@ -88,19 +89,19 @@ def run_validation(model, validation_ds, tokenizer_src, tokenizer_tgt, max_len,
88
 
89
  # Evaluate the character error rate
90
  # Compute the char error rate
91
- metric = torchmetrics.CharErrorRate()
92
  cer = metric(predicted, expected)
93
  writer.add_scalar('validation cer', cer, global_step)
94
  writer.flush()
95
 
96
  # Compute the word error rate
97
- metric = torchmetrics.WordErrorRate()
98
  wer = metric(predicted, expected)
99
  writer.add_scalar('validation wer', wer, global_step)
100
  writer.flush()
101
 
102
  # Compute the BLEU metric
103
- metric = torchmetrics.BLEUScore()
104
  bleu = metric(predicted, expected)
105
  writer.add_scalar('validation BLEU', bleu, global_step)
106
  writer.flush()
 
3
  import torch
4
  import torch.nn as nn
5
  import torchmetrics
6
+ from torchmetrics.text import BLEUScore, CharErrorRate, WordErrorRate
7
 
8
  from datasets import load_dataset
9
  from tokenizers import Tokenizer
 
89
 
90
  # Evaluate the character error rate
91
  # Compute the char error rate
92
+ metric = CharErrorRate()
93
  cer = metric(predicted, expected)
94
  writer.add_scalar('validation cer', cer, global_step)
95
  writer.flush()
96
 
97
  # Compute the word error rate
98
+ metric = WordErrorRate()
99
  wer = metric(predicted, expected)
100
  writer.add_scalar('validation wer', wer, global_step)
101
  writer.flush()
102
 
103
  # Compute the BLEU metric
104
+ metric = BLEUScore()
105
  bleu = metric(predicted, expected)
106
  writer.add_scalar('validation BLEU', bleu, global_step)
107
  writer.flush()