Vaishnav14220
commited on
Commit
·
55ff09b
1
Parent(s):
9b99b56
Use sacrebleu directly for training metrics
Browse files- src/train_forward.py +7 -9
- src/train_retro.py +7 -9
src/train_forward.py
CHANGED
|
@@ -2,7 +2,7 @@
|
|
| 2 |
Forward synthesis model training script.
|
| 3 |
Trains T5 model to predict products from reactants.
|
| 4 |
"""
|
| 5 |
-
import
|
| 6 |
import numpy as np
|
| 7 |
from transformers import (
|
| 8 |
AutoTokenizer,
|
|
@@ -71,23 +71,21 @@ def main():
|
|
| 71 |
collator = DataCollatorForSeq2Seq(tokenizer, model=model, padding=True)
|
| 72 |
|
| 73 |
# Metrics
|
| 74 |
-
metric = evaluate.load("sacrebleu")
|
| 75 |
-
|
| 76 |
def compute_metrics(eval_pred):
|
| 77 |
preds, labels = eval_pred
|
| 78 |
preds = np.where(preds != -100, preds, tokenizer.pad_token_id)
|
| 79 |
labels = np.where(labels != -100, labels, tokenizer.pad_token_id)
|
| 80 |
-
|
| 81 |
decoded_preds = tokenizer.batch_decode(preds, skip_special_tokens=True)
|
| 82 |
decoded_labels = tokenizer.batch_decode(labels, skip_special_tokens=True)
|
| 83 |
-
|
| 84 |
decoded_preds = [p.strip() for p in decoded_preds]
|
| 85 |
decoded_labels = [l.strip() for l in decoded_labels]
|
| 86 |
-
|
| 87 |
-
bleu =
|
| 88 |
exact = np.mean([p == l for p, l in zip(decoded_preds, decoded_labels)])
|
| 89 |
-
|
| 90 |
-
return {"bleu": bleu
|
| 91 |
|
| 92 |
# Trainer
|
| 93 |
print("\nInitializing trainer...")
|
|
|
|
| 2 |
Forward synthesis model training script.
|
| 3 |
Trains T5 model to predict products from reactants.
|
| 4 |
"""
|
| 5 |
+
import sacrebleu
|
| 6 |
import numpy as np
|
| 7 |
from transformers import (
|
| 8 |
AutoTokenizer,
|
|
|
|
| 71 |
collator = DataCollatorForSeq2Seq(tokenizer, model=model, padding=True)
|
| 72 |
|
| 73 |
# Metrics
|
|
|
|
|
|
|
| 74 |
def compute_metrics(eval_pred):
|
| 75 |
preds, labels = eval_pred
|
| 76 |
preds = np.where(preds != -100, preds, tokenizer.pad_token_id)
|
| 77 |
labels = np.where(labels != -100, labels, tokenizer.pad_token_id)
|
| 78 |
+
|
| 79 |
decoded_preds = tokenizer.batch_decode(preds, skip_special_tokens=True)
|
| 80 |
decoded_labels = tokenizer.batch_decode(labels, skip_special_tokens=True)
|
| 81 |
+
|
| 82 |
decoded_preds = [p.strip() for p in decoded_preds]
|
| 83 |
decoded_labels = [l.strip() for l in decoded_labels]
|
| 84 |
+
|
| 85 |
+
bleu = sacrebleu.corpus_bleu(decoded_preds, [decoded_labels])
|
| 86 |
exact = np.mean([p == l for p, l in zip(decoded_preds, decoded_labels)])
|
| 87 |
+
|
| 88 |
+
return {"bleu": bleu.score, "exact_match": exact}
|
| 89 |
|
| 90 |
# Trainer
|
| 91 |
print("\nInitializing trainer...")
|
src/train_retro.py
CHANGED
|
@@ -2,7 +2,7 @@
|
|
| 2 |
Retrosynthesis model training script.
|
| 3 |
Trains T5 model to predict reactants from products.
|
| 4 |
"""
|
| 5 |
-
import
|
| 6 |
import numpy as np
|
| 7 |
from transformers import (
|
| 8 |
AutoTokenizer,
|
|
@@ -71,23 +71,21 @@ def main():
|
|
| 71 |
collator = DataCollatorForSeq2Seq(tokenizer, model=model, padding=True)
|
| 72 |
|
| 73 |
# Metrics
|
| 74 |
-
metric = evaluate.load("sacrebleu")
|
| 75 |
-
|
| 76 |
def compute_metrics(eval_pred):
|
| 77 |
preds, labels = eval_pred
|
| 78 |
preds = np.where(preds != -100, preds, tokenizer.pad_token_id)
|
| 79 |
labels = np.where(labels != -100, labels, tokenizer.pad_token_id)
|
| 80 |
-
|
| 81 |
decoded_preds = tokenizer.batch_decode(preds, skip_special_tokens=True)
|
| 82 |
decoded_labels = tokenizer.batch_decode(labels, skip_special_tokens=True)
|
| 83 |
-
|
| 84 |
decoded_preds = [p.strip() for p in decoded_preds]
|
| 85 |
decoded_labels = [l.strip() for l in decoded_labels]
|
| 86 |
-
|
| 87 |
-
bleu =
|
| 88 |
exact = np.mean([p == l for p, l in zip(decoded_preds, decoded_labels)])
|
| 89 |
-
|
| 90 |
-
return {"bleu": bleu
|
| 91 |
|
| 92 |
# Trainer
|
| 93 |
print("\nInitializing trainer...")
|
|
|
|
| 2 |
Retrosynthesis model training script.
|
| 3 |
Trains T5 model to predict reactants from products.
|
| 4 |
"""
|
| 5 |
+
import sacrebleu
|
| 6 |
import numpy as np
|
| 7 |
from transformers import (
|
| 8 |
AutoTokenizer,
|
|
|
|
| 71 |
collator = DataCollatorForSeq2Seq(tokenizer, model=model, padding=True)
|
| 72 |
|
| 73 |
# Metrics
|
|
|
|
|
|
|
| 74 |
def compute_metrics(eval_pred):
|
| 75 |
preds, labels = eval_pred
|
| 76 |
preds = np.where(preds != -100, preds, tokenizer.pad_token_id)
|
| 77 |
labels = np.where(labels != -100, labels, tokenizer.pad_token_id)
|
| 78 |
+
|
| 79 |
decoded_preds = tokenizer.batch_decode(preds, skip_special_tokens=True)
|
| 80 |
decoded_labels = tokenizer.batch_decode(labels, skip_special_tokens=True)
|
| 81 |
+
|
| 82 |
decoded_preds = [p.strip() for p in decoded_preds]
|
| 83 |
decoded_labels = [l.strip() for l in decoded_labels]
|
| 84 |
+
|
| 85 |
+
bleu = sacrebleu.corpus_bleu(decoded_preds, [decoded_labels])
|
| 86 |
exact = np.mean([p == l for p, l in zip(decoded_preds, decoded_labels)])
|
| 87 |
+
|
| 88 |
+
return {"bleu": bleu.score, "exact_match": exact}
|
| 89 |
|
| 90 |
# Trainer
|
| 91 |
print("\nInitializing trainer...")
|