| import json |
| import torch |
|
|
| from bleu import list_bleu |
|
|
def is_rank_0():
    """Return True when distributed training is not initialized, or when
    this process is global rank 0 of the initialized process group."""
    if not torch.distributed.is_initialized():
        # Single-process (non-distributed) runs are always "rank 0".
        return True
    return torch.distributed.get_rank() == 0
|
|
class TextGenerationScorer:
    """Score generated token sequences against reference labels.

    Computes corpus BLEU and exact-match accuracy over a batch of flattened
    predictions, and (on rank 0 only) writes one JSON record per example to
    ``output_path`` in JSON-lines format.
    """

    def __init__(self, tokenizer, bos_id, eos_id, output_path):
        """
        Args:
            tokenizer: object with a ``decode(ids, skip_special_tokens=...,
                clean_up_tokenization_spaces=...)`` method (HF-style assumed
                — TODO confirm against caller).
            bos_id: begin-of-sequence token id; skipped when extracting sequences.
            eos_id: end-of-sequence token id; truncates extraction.
            output_path: path of the JSONL prediction dump written on rank 0.
        """
        self.bos_id = bos_id
        self.eos_id = eos_id
        self.output_path = output_path
        self.tokenizer = tokenizer

    def __call__(self, prediction):
        """Evaluate one batch of predictions.

        ``prediction`` carries flattened token-id arrays (``predictions``,
        ``label_ids``) plus per-example lengths (``predictions_size``,
        ``label_size``) used to slice the flat arrays back into sequences.

        Returns:
            dict with keys ``bleu``, ``accuracy``, ``correct`` and ``total``.
        """
        preds = prediction.predictions
        preds_size = prediction.predictions_size
        label_ids = prediction.label_ids
        label_size = prediction.label_size
        p_start, l_start = 0, 0
        correct, total = 0, 0
        ref, hyp = [], []
        # Hoisted out of the loop: the rank cannot change mid-call.
        rank_0 = is_rank_0()
        records = []
        for idx, (p_size, l_size) in enumerate(zip(preds_size, label_size)):
            p_end = p_start + p_size
            l_end = l_start + l_size
            pred = self.get_sequence(preds[p_start: p_end])
            label = self.get_sequence(label_ids[l_start: l_end])
            p_start = p_end
            l_start = l_end
            if pred == label:
                correct += 1
            total += 1
            if rank_0:
                pred_text = self.tokenizer.decode(pred, skip_special_tokens=True, clean_up_tokenization_spaces=True).strip()
                label_text = self.tokenizer.decode(label, skip_special_tokens=True, clean_up_tokenization_spaces=True).strip()
                ref.append(label_text)
                hyp.append(pred_text)
                records.append({"idx": idx, "pred": pred_text, "label": label_text})
        if rank_0:
            # Fix: the original opened the file before the loop and never
            # closed it (handle leak). Buffer records and write them inside
            # a context manager so the file is always closed.
            with open(self.output_path, "w") as fout:
                for record in records:
                    fout.write(json.dumps(record) + "\n")
        # NOTE: as in the original, BLEU is computed on every rank; non-rank-0
        # workers pass empty ref/hyp lists.
        score = list_bleu([ref], hyp)
        return {
            "bleu": score,
            # Guard against an empty evaluation set; the original raised
            # ZeroDivisionError when there were no examples.
            "accuracy": correct / total if total else 0.0,
            "correct": correct,
            "total": total
        }

    def get_sequence(self, seq):
        """Extract token ids from ``seq``: skip every BOS id, stop at the
        first EOS id, and coerce remaining ids to plain ints."""
        processed_seq = []
        for idx in seq:
            if idx == self.bos_id:
                continue
            if idx == self.eos_id:
                break
            processed_seq.append(int(idx))
        return processed_seq
|
|
|
|
|
|
|
|
|
|