| | """Tests for the agri_awwer package.""" |
| |
|
| | import json |
| | import math |
| |
|
| | from agri_awwer import ( |
| | clean_text, |
| | align_words_dp, |
| | parse_word_weights, |
| | calculate_awwer, |
| | calculate_awwer_components, |
| | calculate_awwer_from_string, |
| | calculate_wer, |
| | calculate_cer, |
| | calculate_mer, |
| | get_word_weight, |
| | ) |
| |
|
| |
|
| | |
| | |
| | |
| |
|
class TestCleanText:
    """Normalisation performed by ``clean_text``."""

    def test_basic_normalization(self):
        """Upper case and punctuation are normalised away."""
        raw = "Hello, World!"
        assert clean_text(raw) == "hello world"

    def test_empty_input(self):
        """Both the empty string and ``None`` yield the empty string."""
        for missing in ("", None):
            assert clean_text(missing) == ""

    def test_punctuation_removal(self):
        """Commas and full stops are stripped from the text."""
        raw = "gehun, aur makka."
        assert clean_text(raw) == "gehun aur makka"

    def test_whitespace_collapse(self):
        """Whitespace surrounding the words is removed."""
        raw = " gehun mein keet "
        assert clean_text(raw) == "gehun mein keet"

    def test_nan_handling(self):
        """A float NaN is treated like missing input."""
        not_a_number = float("nan")
        assert clean_text(not_a_number) == ""
| |
|
| |
|
| | |
| | |
| | |
| |
|
class TestAlignWordsDP:
    """Tests for ``align_words_dp``.

    Each returned op is indexed so that ``op[0]`` is its type tag
    (``"match"``, ``"sub"``, ``"del"`` or ``"ins"``), and the ops come back
    in left-to-right alignment order (see ``test_substitution``).

    Assertions compare the complete list of tags rather than using
    ``all(...)``. ``all`` is vacuously true on an empty alignment, so a
    broken aligner that returned ``[]`` would otherwise pass.
    """

    @staticmethod
    def _op_types(ops):
        """Return only the type tag of each alignment op."""
        return [op[0] for op in ops]

    def test_identical(self):
        """Identical sequences align as one match per word."""
        ops = align_words_dp(["a", "b", "c"], ["a", "b", "c"])
        assert self._op_types(ops) == ["match", "match", "match"]

    def test_substitution(self):
        """A differing final word is a single substitution."""
        ops = align_words_dp(["a", "b"], ["a", "x"])
        assert self._op_types(ops) == ["match", "sub"]

    def test_deletion(self):
        """A missing middle word is a deletion between two matches."""
        ops = align_words_dp(["a", "b", "c"], ["a", "c"])
        # match(a), del(b), match(c) is the only cost-1 alignment here.
        assert self._op_types(ops) == ["match", "del", "match"]

    def test_insertion(self):
        """An extra middle word is an insertion between two matches."""
        ops = align_words_dp(["a", "c"], ["a", "b", "c"])
        # match(a), ins(b), match(c) is the only cost-1 alignment here.
        assert self._op_types(ops) == ["match", "ins", "match"]

    def test_empty_ref(self):
        """Against an empty reference, every hypothesis word is inserted."""
        ops = align_words_dp([], ["a", "b"])
        assert self._op_types(ops) == ["ins", "ins"]

    def test_empty_hyp(self):
        """Against an empty hypothesis, every reference word is deleted."""
        ops = align_words_dp(["a", "b"], [])
        assert self._op_types(ops) == ["del", "del"]
| |
|
| |
|
| | |
| | |
| | |
| |
|
class TestParseWordWeights:
    """``parse_word_weights`` maps ``[[word, weight], ...]`` data to a dict."""

    def test_json_string(self):
        """A JSON-encoded list of pairs decodes into a word -> weight dict."""
        pairs = [["gehun", 4], ["keet", 3]]
        parsed = parse_word_weights(json.dumps(pairs))
        assert parsed == {"gehun": 4.0, "keet": 3.0}

    def test_empty(self):
        """Blank or missing input produces an empty mapping."""
        for blank in ("", None):
            assert parse_word_weights(blank) == {}

    def test_invalid_json(self):
        """Unparseable text yields an empty mapping instead of raising."""
        garbage = "not json"
        assert parse_word_weights(garbage) == {}

    def test_list_input(self):
        """An already-decoded list of pairs is accepted directly."""
        parsed = parse_word_weights([["a", 2], ["b", 3]])
        assert parsed == {"a": 2.0, "b": 3.0}
| |
|
| |
|
| | |
| | |
| | |
| |
|
class TestGetWordWeight:
    """Weight lookup performed by ``get_word_weight``."""

    def test_exact_match(self):
        """A word present in the table returns its weight."""
        table = {"gehun": 4.0}
        assert get_word_weight("gehun", table) == 4.0

    def test_case_insensitive(self):
        """Lookup ignores the case of the queried word."""
        table = {"gehun": 4.0}
        assert get_word_weight("Gehun", table) == 4.0

    def test_default(self):
        """An unknown word falls back to ``default_weight``."""
        table = {"gehun": 4.0}
        weight = get_word_weight("unknown", table, default_weight=1.0)
        assert weight == 1.0

    def test_empty(self):
        """An empty word looked up in an empty table weighs 1.0."""
        weight = get_word_weight("", {})
        assert weight == 1.0
| |
|
| |
|
| | |
| | |
| | |
| |
|
class TestCalculateAWWER:
    """End-to-end checks for ``calculate_awwer`` with a small weight table."""

    def setup_method(self):
        # High-weight domain terms plus one low-weight word ("gaon") that is
        # used as the wrong hypothesis word in test_high_weight_error.
        self.weights = {
            "gehun": 4.0,
            "keet": 4.0,
            "mitti": 3.0,
            "gaon": 1.0,
        }

    def test_perfect_match(self):
        """An exact transcription has zero weighted error."""
        ref = "gehun mein keet laga hai"
        assert calculate_awwer(ref, ref, self.weights) == 0.0

    def test_high_weight_error(self):
        """Mis-transcribing a high-weight word costs more than plain WER."""
        ref = "gehun mein keet laga hai"
        hyp = "gaon mein keet laga hai"
        awwer = calculate_awwer(ref, hyp, self.weights)
        wer = calculate_wer(ref, hyp)

        assert awwer is not None
        assert wer is not None
        assert awwer > wer

    def test_none_on_empty_ref(self):
        """An empty or missing reference yields ``None``, not a score."""
        assert calculate_awwer("", "something", self.weights) is None
        assert calculate_awwer(None, "something", self.weights) is None

    def test_all_deletions(self):
        """Dropping every reference word gives a weighted error of 1.0."""
        ref = "gehun keet"
        hyp = ""
        awwer = calculate_awwer(ref, hyp, self.weights)

        assert awwer is not None
        # The score is a computed float ratio; compare with a tolerance
        # rather than exact equality.
        assert math.isclose(awwer, 1.0)
| |
|
| |
|
| | |
| | |
| | |
| |
|
class TestCalculateAWWERComponents:
    """Error breakdown returned by ``calculate_awwer_components``."""

    def test_breakdown(self):
        """One substituted high-weight word is counted and reported."""
        weights = {"gehun": 4.0, "keet": 4.0}
        result = calculate_awwer_components(
            "gehun mein keet", "gaon mein keet", weights
        )

        expected_counts = {
            "n_substitutions": 1,
            "n_deletions": 0,
            "n_insertions": 0,
        }
        for key, expected in expected_counts.items():
            assert result[key] == expected, key

        flagged = result["high_weight_errors"]
        assert len(flagged) == 1
        assert flagged[0]["ref_word"] == "gehun"
| |
|
| |
|
| | |
| | |
| | |
| |
|
class TestCalculateAWWERFromString:
    """``calculate_awwer_from_string`` takes its weights as a JSON string."""

    def test_json_weights(self):
        """Identical reference and hypothesis score 0.0 with JSON weights."""
        encoded_weights = json.dumps([["gehun", 4], ["keet", 4]])
        sentence = "gehun mein keet"
        score = calculate_awwer_from_string(sentence, sentence, encoded_weights)
        assert score == 0.0
| |
|
| |
|
| | |
| | |
| | |
| |
|
class TestStandardMetrics:
    """Sanity checks for the unweighted WER, CER and MER helpers."""

    def test_wer_perfect(self):
        """An exact match has zero word error rate."""
        text = "hello world"
        assert calculate_wer(text, text) == 0.0

    def test_wer_all_wrong(self):
        """Every word substituted gives a WER of 1.0."""
        score = calculate_wer("a b c", "x y z")
        assert score == 1.0

    def test_wer_none_on_empty_ref(self):
        """An empty reference cannot be scored, so WER is ``None``."""
        assert calculate_wer("", "hello") is None

    def test_wer_empty_hyp(self):
        """An empty hypothesis deletes every word: WER is 1.0."""
        assert calculate_wer("hello world", "") == 1.0

    def test_cer_perfect(self):
        """An exact match has zero character error rate."""
        word = "hello"
        assert calculate_cer(word, word) == 0.0

    def test_cer_nonzero(self):
        """A single wrong character yields a positive CER."""
        score = calculate_cer("abc", "axc")
        assert score is not None and score > 0

    def test_mer_perfect(self):
        """An exact match has zero match error rate."""
        text = "hello world"
        assert calculate_mer(text, text) == 0.0

    def test_mer_bounds(self):
        """MER stays within the closed unit interval."""
        score = calculate_mer("a b c", "x y z")
        assert score is not None
        assert 0.0 <= score <= 1.0

    def test_nan_handling(self):
        """A NaN reference makes every standard metric return ``None``."""
        nan = float("nan")
        for metric in (calculate_wer, calculate_cer, calculate_mer):
            assert metric(nan, "hello") is None
| |
|