Vaishnav14220 committed on
Commit
55ff09b
·
1 Parent(s): 9b99b56

Use sacrebleu directly for training metrics

Browse files
Files changed (2) hide show
  1. src/train_forward.py +7 -9
  2. src/train_retro.py +7 -9
src/train_forward.py CHANGED
@@ -2,7 +2,7 @@
2
  Forward synthesis model training script.
3
  Trains T5 model to predict products from reactants.
4
  """
5
- import evaluate
6
  import numpy as np
7
  from transformers import (
8
  AutoTokenizer,
@@ -71,23 +71,21 @@ def main():
71
  collator = DataCollatorForSeq2Seq(tokenizer, model=model, padding=True)
72
 
73
  # Metrics
74
- metric = evaluate.load("sacrebleu")
75
-
76
  def compute_metrics(eval_pred):
77
  preds, labels = eval_pred
78
  preds = np.where(preds != -100, preds, tokenizer.pad_token_id)
79
  labels = np.where(labels != -100, labels, tokenizer.pad_token_id)
80
-
81
  decoded_preds = tokenizer.batch_decode(preds, skip_special_tokens=True)
82
  decoded_labels = tokenizer.batch_decode(labels, skip_special_tokens=True)
83
-
84
  decoded_preds = [p.strip() for p in decoded_preds]
85
  decoded_labels = [l.strip() for l in decoded_labels]
86
-
87
- bleu = metric.compute(predictions=decoded_preds, references=[[l] for l in decoded_labels])
88
  exact = np.mean([p == l for p, l in zip(decoded_preds, decoded_labels)])
89
-
90
- return {"bleu": bleu["score"], "exact_match": exact}
91
 
92
  # Trainer
93
  print("\nInitializing trainer...")
 
2
  Forward synthesis model training script.
3
  Trains T5 model to predict products from reactants.
4
  """
5
+ import sacrebleu
6
  import numpy as np
7
  from transformers import (
8
  AutoTokenizer,
 
71
  collator = DataCollatorForSeq2Seq(tokenizer, model=model, padding=True)
72
 
73
  # Metrics
 
 
74
  def compute_metrics(eval_pred):
75
  preds, labels = eval_pred
76
  preds = np.where(preds != -100, preds, tokenizer.pad_token_id)
77
  labels = np.where(labels != -100, labels, tokenizer.pad_token_id)
78
+
79
  decoded_preds = tokenizer.batch_decode(preds, skip_special_tokens=True)
80
  decoded_labels = tokenizer.batch_decode(labels, skip_special_tokens=True)
81
+
82
  decoded_preds = [p.strip() for p in decoded_preds]
83
  decoded_labels = [l.strip() for l in decoded_labels]
84
+
85
+ bleu = sacrebleu.corpus_bleu(decoded_preds, [decoded_labels])
86
  exact = np.mean([p == l for p, l in zip(decoded_preds, decoded_labels)])
87
+
88
+ return {"bleu": bleu.score, "exact_match": exact}
89
 
90
  # Trainer
91
  print("\nInitializing trainer...")
src/train_retro.py CHANGED
@@ -2,7 +2,7 @@
2
  Retrosynthesis model training script.
3
  Trains T5 model to predict reactants from products.
4
  """
5
- import evaluate
6
  import numpy as np
7
  from transformers import (
8
  AutoTokenizer,
@@ -71,23 +71,21 @@ def main():
71
  collator = DataCollatorForSeq2Seq(tokenizer, model=model, padding=True)
72
 
73
  # Metrics
74
- metric = evaluate.load("sacrebleu")
75
-
76
  def compute_metrics(eval_pred):
77
  preds, labels = eval_pred
78
  preds = np.where(preds != -100, preds, tokenizer.pad_token_id)
79
  labels = np.where(labels != -100, labels, tokenizer.pad_token_id)
80
-
81
  decoded_preds = tokenizer.batch_decode(preds, skip_special_tokens=True)
82
  decoded_labels = tokenizer.batch_decode(labels, skip_special_tokens=True)
83
-
84
  decoded_preds = [p.strip() for p in decoded_preds]
85
  decoded_labels = [l.strip() for l in decoded_labels]
86
-
87
- bleu = metric.compute(predictions=decoded_preds, references=[[l] for l in decoded_labels])
88
  exact = np.mean([p == l for p, l in zip(decoded_preds, decoded_labels)])
89
-
90
- return {"bleu": bleu["score"], "exact_match": exact}
91
 
92
  # Trainer
93
  print("\nInitializing trainer...")
 
2
  Retrosynthesis model training script.
3
  Trains T5 model to predict reactants from products.
4
  """
5
+ import sacrebleu
6
  import numpy as np
7
  from transformers import (
8
  AutoTokenizer,
 
71
  collator = DataCollatorForSeq2Seq(tokenizer, model=model, padding=True)
72
 
73
  # Metrics
 
 
74
  def compute_metrics(eval_pred):
75
  preds, labels = eval_pred
76
  preds = np.where(preds != -100, preds, tokenizer.pad_token_id)
77
  labels = np.where(labels != -100, labels, tokenizer.pad_token_id)
78
+
79
  decoded_preds = tokenizer.batch_decode(preds, skip_special_tokens=True)
80
  decoded_labels = tokenizer.batch_decode(labels, skip_special_tokens=True)
81
+
82
  decoded_preds = [p.strip() for p in decoded_preds]
83
  decoded_labels = [l.strip() for l in decoded_labels]
84
+
85
+ bleu = sacrebleu.corpus_bleu(decoded_preds, [decoded_labels])
86
  exact = np.mean([p == l for p, l in zip(decoded_preds, decoded_labels)])
87
+
88
+ return {"bleu": bleu.score, "exact_match": exact}
89
 
90
  # Trainer
91
  print("\nInitializing trainer...")