| | |
| | |
| | |
| | |
| |
|
| | import argparse |
| | import unittest |
| |
|
| | import tests.utils as test_utils |
| | import torch |
| | from fairseq.sequence_scorer import SequenceScorer |
| |
|
| |
|
class TestSequenceScorer(unittest.TestCase):
    """Check ``SequenceScorer`` against hand-specified per-token probabilities.

    The test feeds three fixed source/target pairs through a dummy translation
    task whose model is configured via ``args.beam_probs`` and verifies that
    the scorer returns the reference tokens together with the expected
    positional and sentence-level scores.
    """

    def test_sequence_scorer(self):
        """Score three known sentence pairs and compare with expected values."""
        # Construct a dummy dictionary with two real words. The assertions pin
        # the indices of the special symbols that the rest of the test relies on.
        d = test_utils.dummy_dictionary(vocab_size=2)
        self.assertEqual(d.pad(), 1)
        self.assertEqual(d.eos(), 2)
        self.assertEqual(d.unk(), 3)
        eos = d.eos()
        w1 = 4
        w2 = 5

        # Construct the dataloader: three source/target pairs with target
        # lengths 4, 3 and 2 (including eos).
        data = [
            {
                "source": torch.LongTensor([w1, w2, eos]),
                "target": torch.LongTensor([w1, w2, w1, eos]),
            },
            {
                "source": torch.LongTensor([w2, eos]),
                "target": torch.LongTensor([w2, w1, eos]),
            },
            {
                "source": torch.LongTensor([w2, eos]),
                "target": torch.LongTensor([w2, eos]),
            },
        ]
        data_itr = test_utils.dummy_dataloader(data)

        # Specify the probabilities handed to the dummy model: one tensor per
        # target position, one row per sentence. Cross-checking with
        # ``expected_scores`` below, the columns line up with
        # (eos, unk, w1, w2). How the dummy model consumes this table is
        # defined in ``tests.utils`` -- NOTE(review): confirm there.
        args = argparse.Namespace()
        unk = 0.0
        args.beam_probs = [
            # step 0
            torch.FloatTensor(
                [
                    # eos   unk   w1    w2
                    [0.0, unk, 0.6, 0.4],  # sentence 1
                    [0.0, unk, 0.4, 0.6],  # sentence 2
                    [0.0, unk, 0.7, 0.3],  # sentence 3
                ]
            ),
            # step 1
            torch.FloatTensor(
                [
                    # eos   unk   w1    w2
                    [0.0, unk, 0.2, 0.7],  # sentence 1
                    [0.0, unk, 0.8, 0.2],  # sentence 2
                    [0.7, unk, 0.1, 0.2],  # sentence 3
                ]
            ),
            # step 2
            torch.FloatTensor(
                [
                    # eos    unk   w1     w2
                    [0.10, unk, 0.50, 0.4],  # sentence 1
                    [0.15, unk, 0.15, 0.7],  # sentence 2
                    [0.00, unk, 0.00, 0.0],  # sentence 3 (already finished)
                ]
            ),
            # step 3
            torch.FloatTensor(
                [
                    # eos   unk   w1     w2
                    [0.9, unk, 0.05, 0.05],  # sentence 1
                    [0.0, unk, 0.00, 0.0],  # sentence 2 (already finished)
                    [0.0, unk, 0.00, 0.0],  # sentence 3 (already finished)
                ]
            ),
        ]
        # Probability of each reference target token, read off the table above.
        expected_scores = [
            [0.6, 0.7, 0.5, 0.9],  # sentence 1: w1 w2 w1 eos
            [0.6, 0.8, 0.15],  # sentence 2: w2 w1 eos
            [0.3, 0.7],  # sentence 3: w2 eos
        ]

        task = test_utils.TestTranslationTask.setup_task(args, d, d)
        model = task.build_model(args)
        scorer = SequenceScorer(task.target_dictionary)
        for sample in data_itr:
            hypos = task.inference_step(scorer, [model], sample)
            # ``sample["id"]`` indexes back into ``data`` / ``expected_scores``.
            for sent_id, sent_hypos in zip(sample["id"].tolist(), hypos):
                self.assertHypoTokens(sent_hypos[0], data[sent_id]["target"])
                self.assertHypoScore(sent_hypos[0], expected_scores[sent_id])

    def assertHypoTokens(self, hypo, tokens):
        """Assert that ``hypo["tokens"]`` equals ``tokens`` element-wise."""
        self.assertTensorEqual(hypo["tokens"], torch.LongTensor(tokens))

    def assertHypoScore(self, hypo, pos_probs, normalized=True, lenpen=1.0):
        """Assert that a hypothesis carries the expected scores.

        Args:
            hypo: mapping with ``"positional_scores"``, ``"tokens"`` and
                ``"score"`` entries.
            pos_probs: expected per-position probabilities (not logs).
            normalized: if true, the summed log-probability is divided by
                ``len(pos_probs) ** lenpen`` before comparing to
                ``hypo["score"]``.
            lenpen: exponent used for the length normalisation.
        """
        pos_scores = torch.FloatTensor(pos_probs).log()
        self.assertAlmostEqual(hypo["positional_scores"], pos_scores)
        self.assertEqual(pos_scores.numel(), hypo["tokens"].numel())
        score = pos_scores.sum()
        if normalized:
            score /= pos_scores.numel() ** lenpen
        self.assertLess(abs(score - hypo["score"]), 1e-6)

    def assertAlmostEqual(self, t1, t2, *args, **kwargs):
        """Assert two tensors have equal size and differ by less than 1e-4.

        This overrides ``unittest.TestCase.assertAlmostEqual``. To keep the
        inherited scalar contract usable, any call where the operands are not
        both tensors is forwarded unchanged to the base implementation; the
        extra positional/keyword arguments only apply to that fallback.
        """
        if not (torch.is_tensor(t1) and torch.is_tensor(t2)):
            return super().assertAlmostEqual(t1, t2, *args, **kwargs)
        self.assertEqual(t1.size(), t2.size(), "size mismatch")
        self.assertLess((t1 - t2).abs().max(), 1e-4)

    def assertTensorEqual(self, t1, t2):
        """Assert two tensors have equal size and no differing elements."""
        self.assertEqual(t1.size(), t2.size(), "size mismatch")
        self.assertEqual(t1.ne(t2).long().sum(), 0)
| |
|
| |
|
# Run the test suite when this file is executed directly as a script.
if __name__ == "__main__":
    unittest.main()
| |
|