# Agri_AWWER_Toolkit/tests/test_awwer.py
# Uploaded with huggingface_hub by 15laddoo (commit 1b3d38f, verified).
"""Tests for the agri_awwer package."""
import json
import math
from agri_awwer import (
clean_text,
align_words_dp,
parse_word_weights,
calculate_awwer,
calculate_awwer_components,
calculate_awwer_from_string,
calculate_wer,
calculate_cer,
calculate_mer,
get_word_weight,
)
# ---------------------------------------------------------------------------
# clean_text
# ---------------------------------------------------------------------------
class TestCleanText:
    """Normalisation behaviour of ``clean_text``."""

    def test_basic_normalization(self):
        """Mixed case and punctuation reduce to lowercase words."""
        normalized = clean_text("Hello, World!")
        assert normalized == "hello world"

    def test_empty_input(self):
        """Both an empty string and ``None`` normalise to an empty string."""
        for blank in ("", None):
            assert clean_text(blank) == ""

    def test_punctuation_removal(self):
        """Commas and full stops are stripped from the text."""
        expected = "gehun aur makka"
        assert clean_text("gehun, aur makka.") == expected

    def test_whitespace_collapse(self):
        """Surrounding whitespace is trimmed from the normalised text."""
        expected = "gehun mein keet"
        assert clean_text(" gehun mein keet ") == expected

    def test_nan_handling(self):
        """A float NaN normalises to an empty string."""
        not_a_number = float("nan")
        assert clean_text(not_a_number) == ""
# ---------------------------------------------------------------------------
# align_words_dp
# ---------------------------------------------------------------------------
class TestAlignWordsDP:
    """Operation sequences produced by the ``align_words_dp`` aligner.

    Each returned operation carries its type tag at ``op[0]``: one of
    ``"match"``, ``"sub"``, ``"ins"`` or ``"del"``. The ops are listed in
    forward (left-to-right) order, as ``test_substitution`` pins.
    """

    @staticmethod
    def _op_types(ops):
        """Return only the type tags of an alignment, in order."""
        return [op[0] for op in ops]

    def test_identical(self):
        types = self._op_types(align_words_dp(["a", "b", "c"], ["a", "b", "c"]))
        # Compare the whole list: all(...) on an empty alignment is vacuously
        # True, so it would not catch an aligner that returns nothing.
        assert types == ["match", "match", "match"]

    def test_substitution(self):
        types = self._op_types(align_words_dp(["a", "b"], ["a", "x"]))
        assert types == ["match", "sub"]

    def test_deletion(self):
        # The only minimum-cost alignment keeps "a" and "c" and drops "b".
        types = self._op_types(align_words_dp(["a", "b", "c"], ["a", "c"]))
        assert types == ["match", "del", "match"]

    def test_insertion(self):
        # The only minimum-cost alignment keeps "a" and "c" and adds "b".
        types = self._op_types(align_words_dp(["a", "c"], ["a", "b", "c"]))
        assert types == ["match", "ins", "match"]

    def test_empty_ref(self):
        # Every hypothesis word must appear as exactly one insertion.
        types = self._op_types(align_words_dp([], ["a", "b"]))
        assert types == ["ins", "ins"]

    def test_empty_hyp(self):
        # Every reference word must appear as exactly one deletion.
        types = self._op_types(align_words_dp(["a", "b"], []))
        assert types == ["del", "del"]
# ---------------------------------------------------------------------------
# parse_word_weights
# ---------------------------------------------------------------------------
class TestParseWordWeights:
    """Parsing of word-weight specifications into ``{word: float}`` dicts."""

    def test_json_string(self):
        """A JSON-encoded list of ``[word, weight]`` pairs is decoded."""
        payload = json.dumps([["gehun", 4], ["keet", 3]])
        parsed = parse_word_weights(payload)
        assert parsed == {"gehun": 4.0, "keet": 3.0}

    def test_empty(self):
        """Both an empty string and ``None`` yield an empty mapping."""
        for blank in ("", None):
            assert parse_word_weights(blank) == {}

    def test_invalid_json(self):
        """Text that is not valid JSON yields an empty mapping."""
        assert parse_word_weights("not json") == {}

    def test_list_input(self):
        """An already-decoded list of pairs is accepted directly."""
        pairs = [["a", 2], ["b", 3]]
        assert parse_word_weights(pairs) == {"a": 2.0, "b": 3.0}
# ---------------------------------------------------------------------------
# get_word_weight
# ---------------------------------------------------------------------------
class TestGetWordWeight:
    """Single-word lookups through ``get_word_weight``."""

    def test_exact_match(self):
        """A word present in the table returns its stored weight."""
        table = {"gehun": 4.0}
        assert get_word_weight("gehun", table) == 4.0

    def test_case_insensitive(self):
        """Capitalisation of the queried word does not affect the lookup."""
        table = {"gehun": 4.0}
        assert get_word_weight("Gehun", table) == 4.0

    def test_default(self):
        """An unknown word returns the supplied ``default_weight``."""
        weight = get_word_weight("unknown", {"gehun": 4.0}, default_weight=1.0)
        assert weight == 1.0

    def test_empty(self):
        """An empty word with an empty table returns 1.0."""
        weight = get_word_weight("", {})
        assert weight == 1.0
# ---------------------------------------------------------------------------
# calculate_awwer
# ---------------------------------------------------------------------------
class TestCalculateAWWER:
    """End-to-end scoring with ``calculate_awwer`` on short phrases."""

    def setup_method(self):
        """Give every test its own fresh weight table."""
        self.weights = {"gehun": 4.0, "keet": 4.0, "mitti": 3.0, "gaon": 1.0}

    def test_perfect_match(self):
        """An exact transcript scores zero."""
        sentence = "gehun mein keet laga hai"
        assert calculate_awwer(sentence, sentence, self.weights) == 0.0

    def test_high_weight_error(self):
        """Substituting the weight-4 word "gehun" costs more than plain WER."""
        reference = "gehun mein keet laga hai"
        hypothesis = "gaon mein keet laga hai"
        awwer = calculate_awwer(reference, hypothesis, self.weights)
        wer = calculate_wer(reference, hypothesis)
        assert awwer is not None
        assert wer is not None
        assert awwer > wer

    def test_none_on_empty_ref(self):
        """A missing or empty reference yields ``None`` instead of a score."""
        for missing_ref in ("", None):
            assert calculate_awwer(missing_ref, "something", self.weights) is None

    def test_all_deletions(self):
        """Deleting every reference word makes error weight equal total weight."""
        awwer = calculate_awwer("gehun keet", "", self.weights)
        assert awwer == 1.0
# ---------------------------------------------------------------------------
# calculate_awwer_components
# ---------------------------------------------------------------------------
class TestCalculateAWWERComponents:
    """Per-error-type breakdown returned by ``calculate_awwer_components``."""

    def test_breakdown(self):
        """One substitution of a weighted word is counted and reported."""
        result = calculate_awwer_components(
            "gehun mein keet",
            "gaon mein keet",
            {"gehun": 4.0, "keet": 4.0},
        )
        counts = (
            result["n_substitutions"],
            result["n_deletions"],
            result["n_insertions"],
        )
        assert counts == (1, 0, 0)
        flagged = result["high_weight_errors"]
        assert len(flagged) == 1
        assert flagged[0]["ref_word"] == "gehun"
# ---------------------------------------------------------------------------
# calculate_awwer_from_string
# ---------------------------------------------------------------------------
class TestCalculateAWWERFromString:
    """``calculate_awwer_from_string`` with weights given as JSON text."""

    def test_json_weights(self):
        """Identical reference and hypothesis score zero."""
        sentence = "gehun mein keet"
        weights_json = json.dumps([["gehun", 4], ["keet", 4]])
        assert calculate_awwer_from_string(sentence, sentence, weights_json) == 0.0
# ---------------------------------------------------------------------------
# calculate_wer / calculate_cer / calculate_mer
# ---------------------------------------------------------------------------
class TestStandardMetrics:
    """Unweighted WER, CER and MER helpers."""

    def test_wer_perfect(self):
        assert calculate_wer("hello world", "hello world") == 0.0

    def test_wer_all_wrong(self):
        wer = calculate_wer("a b c", "x y z")
        assert wer == 1.0

    def test_wer_none_on_empty_ref(self):
        assert calculate_wer("", "hello") is None

    def test_wer_empty_hyp(self):
        assert calculate_wer("hello world", "") == 1.0

    def test_cer_perfect(self):
        assert calculate_cer("hello", "hello") == 0.0

    def test_cer_nonzero(self):
        # "abc" -> "axc": one substituted character out of three reference
        # characters, so CER must be exactly 1/3.
        cer = calculate_cer("abc", "axc")
        assert cer is not None
        # Pin the value rather than just its sign: `cer > 0` would also pass
        # for a wrongly-normalised rate. isclose avoids exact float equality.
        assert math.isclose(cer, 1 / 3)

    def test_mer_perfect(self):
        assert calculate_mer("hello world", "hello world") == 0.0

    def test_mer_bounds(self):
        mer = calculate_mer("a b c", "x y z")
        assert mer is not None
        assert 0.0 <= mer <= 1.0

    def test_nan_handling(self):
        # A NaN reference is treated like a missing reference by every metric.
        assert calculate_wer(float("nan"), "hello") is None
        assert calculate_cer(float("nan"), "hello") is None
        assert calculate_mer(float("nan"), "hello") is None