"""Tests for the agri_awwer package."""

import json
import math

from agri_awwer import (
    clean_text,
    align_words_dp,
    parse_word_weights,
    calculate_awwer,
    calculate_awwer_components,
    calculate_awwer_from_string,
    calculate_wer,
    calculate_cer,
    calculate_mer,
    get_word_weight,
)


# ---------------------------------------------------------------------------
# clean_text
# ---------------------------------------------------------------------------


class TestCleanText:
    """``clean_text`` lowercases, strips punctuation and normalizes spacing.

    ``None`` and NaN inputs are expected to come back as the empty string.
    """

    def test_basic_normalization(self):
        assert clean_text("Hello, World!") == "hello world"

    def test_empty_input(self):
        assert clean_text("") == ""
        assert clean_text(None) == ""

    def test_punctuation_removal(self):
        assert clean_text("gehun, aur makka.") == "gehun aur makka"

    def test_whitespace_collapse(self):
        # Use runs of spaces (leading, interior and trailing) so the test
        # actually exercises collapsing, not just stripping.
        assert clean_text("  gehun   mein  keet  ") == "gehun mein keet"

    def test_nan_handling(self):
        # NaN typically arrives from empty cells in pandas-loaded transcripts.
        assert clean_text(float("nan")) == ""


# ---------------------------------------------------------------------------
# align_words_dp
# ---------------------------------------------------------------------------


class TestAlignWordsDP:
    """``align_words_dp(ref, hyp)`` returns one op per aligned position.

    Each op's first element is its type: ``"match"``, ``"sub"``, ``"del"``
    or ``"ins"``.

    Assertions compare the full list of op types rather than using
    ``all(...)``. ``all()`` over an empty list is ``True``, so a broken
    aligner returning ``[]`` would otherwise pass.
    """

    def test_identical(self):
        ops = align_words_dp(["a", "b", "c"], ["a", "b", "c"])
        assert [op[0] for op in ops] == ["match", "match", "match"]

    def test_substitution(self):
        ops = align_words_dp(["a", "b"], ["a", "x"])
        types = [op[0] for op in ops]
        assert types == ["match", "sub"]

    def test_deletion(self):
        ops = align_words_dp(["a", "b", "c"], ["a", "c"])
        types = [op[0] for op in ops]
        # Exactly one deletion and two matches; order-agnostic on purpose.
        assert sorted(types) == ["del", "match", "match"]

    def test_insertion(self):
        ops = align_words_dp(["a", "c"], ["a", "b", "c"])
        types = [op[0] for op in ops]
        # Exactly one insertion and two matches; order-agnostic on purpose.
        assert sorted(types) == ["ins", "match", "match"]

    def test_empty_ref(self):
        ops = align_words_dp([], ["a", "b"])
        assert [op[0] for op in ops] == ["ins", "ins"]

    def test_empty_hyp(self):
        ops = align_words_dp(["a", "b"], [])
        assert [op[0] for op in ops] == ["del", "del"]
# ---------------------------------------------------------------------------
# parse_word_weights
# ---------------------------------------------------------------------------


class TestParseWordWeights:
    """``parse_word_weights`` turns ``[word, weight]`` pairs into a dict.

    It accepts either a JSON string or an already-decoded list. Empty, None
    or malformed input yields ``{}``.
    """

    def test_json_string(self):
        s = json.dumps([["gehun", 4], ["keet", 3]])
        w = parse_word_weights(s)
        assert w == {"gehun": 4.0, "keet": 3.0}

    def test_empty(self):
        assert parse_word_weights("") == {}
        assert parse_word_weights(None) == {}

    def test_invalid_json(self):
        assert parse_word_weights("not json") == {}

    def test_list_input(self):
        w = parse_word_weights([["a", 2], ["b", 3]])
        assert w == {"a": 2.0, "b": 3.0}


# ---------------------------------------------------------------------------
# get_word_weight
# ---------------------------------------------------------------------------


class TestGetWordWeight:
    """``get_word_weight`` looks up a word case-insensitively.

    A word missing from the weights falls back to the default weight.
    """

    def test_exact_match(self):
        assert get_word_weight("gehun", {"gehun": 4.0}) == 4.0

    def test_case_insensitive(self):
        assert get_word_weight("Gehun", {"gehun": 4.0}) == 4.0

    def test_default(self):
        assert (
            get_word_weight("unknown", {"gehun": 4.0}, default_weight=1.0)
            == 1.0
        )

    def test_empty(self):
        # Relies on the implicit default weight being 1.0.
        assert get_word_weight("", {}) == 1.0


# ---------------------------------------------------------------------------
# calculate_awwer
# ---------------------------------------------------------------------------


class TestCalculateAWWER:
    """Agriculture-weighted WER: errors on heavily weighted words cost more."""

    def setup_method(self):
        self.weights = {
            "gehun": 4.0,
            "keet": 4.0,
            "mitti": 3.0,
            "gaon": 1.0,
        }

    def test_perfect_match(self):
        ref = "gehun mein keet laga hai"
        assert calculate_awwer(ref, ref, self.weights) == 0.0

    def test_high_weight_error(self):
        ref = "gehun mein keet laga hai"
        hyp = "gaon mein keet laga hai"
        awwer = calculate_awwer(ref, hyp, self.weights)
        wer = calculate_wer(ref, hyp)
        # AWWER should be > WER because gehun (weight 4) was substituted.
        assert awwer is not None
        assert wer is not None
        assert awwer > wer

    def test_none_on_empty_ref(self):
        assert calculate_awwer("", "something", self.weights) is None
        assert calculate_awwer(None, "something", self.weights) is None

    def test_all_deletions(self):
        ref = "gehun keet"
        hyp = ""
        awwer = calculate_awwer(ref, hyp, self.weights)
        # All reference words deleted → error_weight == total_weight
        # → AWWER = 1.0.
        assert awwer == 1.0


# ---------------------------------------------------------------------------
# calculate_awwer_components
# ---------------------------------------------------------------------------


class TestCalculateAWWERComponents:
    """The component breakdown exposes error counts and high-weight errors."""

    def test_breakdown(self):
        weights = {"gehun": 4.0, "keet": 4.0}
        ref = "gehun mein keet"
        hyp = "gaon mein keet"
        result = calculate_awwer_components(ref, hyp, weights)
        assert result["n_substitutions"] == 1
        assert result["n_deletions"] == 0
        assert result["n_insertions"] == 0
        assert len(result["high_weight_errors"]) == 1
        assert result["high_weight_errors"][0]["ref_word"] == "gehun"


# ---------------------------------------------------------------------------
# calculate_awwer_from_string
# ---------------------------------------------------------------------------


class TestCalculateAWWERFromString:
    """``calculate_awwer_from_string`` accepts weights as a JSON string."""

    def test_json_weights(self):
        weights_json = json.dumps([["gehun", 4], ["keet", 4]])
        ref = "gehun mein keet"
        awwer = calculate_awwer_from_string(ref, ref, weights_json)
        assert awwer == 0.0


# ---------------------------------------------------------------------------
# calculate_wer / calculate_cer / calculate_mer
# ---------------------------------------------------------------------------


class TestStandardMetrics:
    """Unweighted WER, CER and MER.

    Each returns ``None`` when the reference is missing.
    """

    def test_wer_perfect(self):
        assert calculate_wer("hello world", "hello world") == 0.0

    def test_wer_all_wrong(self):
        wer = calculate_wer("a b c", "x y z")
        assert wer == 1.0

    def test_wer_none_on_empty_ref(self):
        assert calculate_wer("", "hello") is None

    def test_wer_empty_hyp(self):
        assert calculate_wer("hello world", "") == 1.0

    def test_cer_perfect(self):
        assert calculate_cer("hello", "hello") == 0.0

    def test_cer_nonzero(self):
        cer = calculate_cer("abc", "axc")
        assert cer is not None
        # One character substitution over three reference characters.
        # Pinning the value (not just ``> 0``) catches wrong denominators.
        assert math.isclose(cer, 1 / 3)

    def test_mer_perfect(self):
        assert calculate_mer("hello world", "hello world") == 0.0

    def test_mer_bounds(self):
        mer = calculate_mer("a b c", "x y z")
        assert mer is not None
        assert 0.0 <= mer <= 1.0

    def test_nan_handling(self):
        assert calculate_wer(float("nan"), "hello") is None
        assert calculate_cer(float("nan"), "hello") is None
        assert calculate_mer(float("nan"), "hello") is None