Commit 62dcdc0 (parent: 2b97908) — "Added files"
Browse files
transformer_from_scratch/config.py
CHANGED
|
@@ -21,7 +21,7 @@ def get_weights_file_path(config, epoch):
|
|
| 21 |
model_folder = config["model_folder"]
|
| 22 |
model_basename = config["model_basename"]
|
| 23 |
model_filename = f"{model_basename}{epoch}.pt"
|
| 24 |
-
return str(Path(  [NOTE: removed line truncated in this diff rendering — full original line not recoverable from this view; the replacement line below shows the corrected version]
|
| 25 |
|
| 26 |
|
| 27 |
def latest_weights_file_path(config):
|
|
|
|
| 21 |
model_folder = config["model_folder"]
|
| 22 |
model_basename = config["model_basename"]
|
| 23 |
model_filename = f"{model_basename}{epoch}.pt"
|
| 24 |
+
return str(Path('.') / model_folder / model_filename)
|
| 25 |
|
| 26 |
|
| 27 |
def latest_weights_file_path(config):
|
transformer_from_scratch/train.py
CHANGED
|
@@ -3,6 +3,7 @@ from pathlib import Path
|
|
| 3 |
import torch
|
| 4 |
import torch.nn as nn
|
| 5 |
import torchmetrics
|
|
|
|
| 6 |
|
| 7 |
from datasets import load_dataset
|
| 8 |
from tokenizers import Tokenizer
|
|
@@ -88,19 +89,19 @@ def run_validation(model, validation_ds, tokenizer_src, tokenizer_tgt, max_len,
|
|
| 88 |
|
| 89 |
# Evaluate the character error rate
|
| 90 |
# Compute the char error rate
|
| 91 |
-
metric =
|
| 92 |
cer = metric(predicted, expected)
|
| 93 |
writer.add_scalar('validation cer', cer, global_step)
|
| 94 |
writer.flush()
|
| 95 |
|
| 96 |
# Compute the word error rate
|
| 97 |
-
metric =
|
| 98 |
wer = metric(predicted, expected)
|
| 99 |
writer.add_scalar('validation wer', wer, global_step)
|
| 100 |
writer.flush()
|
| 101 |
|
| 102 |
# Compute the BLEU metric
|
| 103 |
-
metric =
|
| 104 |
bleu = metric(predicted, expected)
|
| 105 |
writer.add_scalar('validation BLEU', bleu, global_step)
|
| 106 |
writer.flush()
|
|
|
|
| 3 |
import torch
|
| 4 |
import torch.nn as nn
|
| 5 |
import torchmetrics
|
| 6 |
+
from torchmetrics.text import BLEUScore, CharErrorRate, WordErrorRate
|
| 7 |
|
| 8 |
from datasets import load_dataset
|
| 9 |
from tokenizers import Tokenizer
|
|
|
|
| 89 |
|
| 90 |
# Evaluate the character error rate
|
| 91 |
# Compute the char error rate
|
| 92 |
+
metric = CharErrorRate()
|
| 93 |
cer = metric(predicted, expected)
|
| 94 |
writer.add_scalar('validation cer', cer, global_step)
|
| 95 |
writer.flush()
|
| 96 |
|
| 97 |
# Compute the word error rate
|
| 98 |
+
metric = WordErrorRate()
|
| 99 |
wer = metric(predicted, expected)
|
| 100 |
writer.add_scalar('validation wer', wer, global_step)
|
| 101 |
writer.flush()
|
| 102 |
|
| 103 |
# Compute the BLEU metric
|
| 104 |
+
metric = BLEUScore()
|
| 105 |
bleu = metric(predicted, expected)
|
| 106 |
writer.add_scalar('validation BLEU', bleu, global_step)
|
| 107 |
writer.flush()
|