|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
import unittest |
|
|
|
|
|
from cer import CER |
|
|
|
|
|
|
|
|
# Single metric instance shared by all test methods in this module.
cer = CER()
|
|
|
|
|
|
|
|
class TestCER(unittest.TestCase):
    """Unit tests for the character error rate returned by ``cer.compute``.

    Every test hands lists of prediction and reference strings to the
    module-level ``cer`` instance and compares the returned rate with a
    hard-coded expected value.
    """

    # Absolute tolerance for comparisons against truncated decimal constants.
    TOLERANCE = 1e-6

    def _assert_cer(self, predictions, references, expected, **compute_kwargs):
        """Compute the error rate and assert it lies within ``TOLERANCE`` of ``expected``.

        Args:
            predictions: List of predicted strings.
            references: List of reference strings.
            expected: Expected error rate.
            **compute_kwargs: Extra keyword arguments forwarded to ``cer.compute``.
        """
        char_error_rate = cer.compute(predictions=predictions, references=references, **compute_kwargs)
        # assertAlmostEqual reports both values on failure, unlike assertTrue(abs(...) < tol).
        self.assertAlmostEqual(char_error_rate, expected, delta=self.TOLERANCE)

    def test_cer_case_sensitive(self):
        """Letters differing only in case are counted as errors (2 of 11 characters)."""
        refs = ["White House"]
        preds = ["white house"]
        self._assert_cer(preds, refs, 0.1818181818)

    def test_cer_whitespace(self):
        """Check how spaces that differ between prediction and reference are scored."""
        cases = [
            # (references, predictions, expected rate)
            (["were wolf"], ["werewolf"], 0.1111111),
            (["werewolf"], ["weae wolf"], 0.25),
            (["were wolf"], ["were wolf"], 0.0),
            # NOTE(review): identical to the previous case as written; presumably the
            # two cases were meant to differ in consecutive whitespace - confirm.
            (["were wolf"], ["were wolf"], 0.0),
        ]
        for refs, preds, expected in cases:
            # subTest keeps later cases running when an earlier one fails.
            with self.subTest(refs=refs, preds=preds):
                self._assert_cer(preds, refs, expected)

    def test_cer_sub(self):
        """One substituted character out of eight yields 0.125."""
        refs = ["werewolf"]
        preds = ["weaewolf"]
        self._assert_cer(preds, refs, 0.125)

    def test_cer_del(self):
        """Prediction carries one extra character ("a") relative to the reference.

        NOTE(review): the method name says "del" although the prediction is
        longer than the reference; name kept so test IDs stay stable.
        """
        refs = ["werewolf"]
        preds = ["wereawolf"]
        self._assert_cer(preds, refs, 0.125)

    def test_cer_insert(self):
        """Prediction lacks one character ("w") that the reference has.

        NOTE(review): the method name says "insert" although the prediction is
        shorter than the reference; name kept so test IDs stay stable.
        """
        refs = ["werewolf"]
        preds = ["wereolf"]
        self._assert_cer(preds, refs, 0.125)

    def test_cer_equal(self):
        """Identical prediction and reference give a rate of exactly 0.0."""
        refs = ["werewolf"]
        char_error_rate = cer.compute(predictions=refs, references=refs)
        self.assertEqual(char_error_rate, 0.0)

    def test_cer_list_of_seqs(self):
        """Multiple sequences are scored together as one rate."""
        refs = ["werewolf", "I am your father"]
        char_error_rate = cer.compute(predictions=refs, references=refs)
        self.assertEqual(char_error_rate, 0.0)

        # One differing character across 28 reference characters in total.
        refs = ["werewolf", "I am your father", "doge"]
        preds = ["werxwolf", "I am your father", "doge"]
        self._assert_cer(preds, refs, 0.03571428)

    def test_correlated_sentences(self):
        """Sentence boundaries that differ between lists, scored with ``concatenate_texts=True``."""
        refs = ["My hovercraft", "is full of eels"]
        preds = ["My hovercraft is full", " of eels"]
        # Expected value is approximately 2 / 28.
        self._assert_cer(preds, refs, 0.071428, concatenate_texts=True)

    def test_cer_unicode(self):
        """Non-ASCII (Chinese) text is scored per character."""
        refs = ["我能吞下玻璃而不伤身体"]
        preds = [" 能吞虾玻璃而 不霜身体啦"]
        self._assert_cer(preds, refs, 0.4545454545)

        refs = ["我能吞下玻璃", "而不伤身体"]
        preds = ["我 能 吞 下 玻 璃", "而不伤身体"]
        self._assert_cer(preds, refs, 0.454545454545)

        refs = ["我能吞下玻璃而不伤身体"]
        char_error_rate = cer.compute(predictions=refs, references=refs)
        # Was assertFalse(char_error_rate, 0.0), which treated 0.0 as the failure
        # message; compare against the expected value explicitly instead.
        self.assertEqual(char_error_rate, 0.0)

    def test_cer_empty(self):
        """An empty reference string makes ``cer.compute`` raise ``ValueError``."""
        refs = [""]
        preds = ["Hypothesis"]
        with self.assertRaises(ValueError):
            cer.compute(predictions=preds, references=refs)
|
|
|
|
|
|
|
|
if __name__ == "__main__":
    # Run the test suite when this file is executed directly as a script.
    unittest.main()
|
|
|