"""Tests for the agri_awwer package."""

import json
import math

from agri_awwer import (
    clean_text,
    align_words_dp,
    parse_word_weights,
    calculate_awwer,
    calculate_awwer_components,
    calculate_awwer_from_string,
    calculate_wer,
    calculate_cer,
    calculate_mer,
    get_word_weight,
)


# ---------------------------------------------------------------------------
# clean_text
# ---------------------------------------------------------------------------

class TestCleanText:
    """Expectations for ``clean_text`` on ordinary and degenerate inputs."""

    def test_basic_normalization(self):
        """Mixed case plus punctuation comes back as lowercase words."""
        cleaned = clean_text("Hello, World!")
        assert cleaned == "hello world"

    def test_empty_input(self):
        """Both the empty string and ``None`` clean to the empty string."""
        for blank in ("", None):
            assert clean_text(blank) == ""

    def test_punctuation_removal(self):
        """Commas and full stops are dropped; the words are kept."""
        cleaned = clean_text("gehun, aur makka.")
        assert cleaned == "gehun aur makka"

    def test_whitespace_collapse(self):
        """Leading, trailing and repeated spaces collapse to single spaces."""
        padded = "  gehun   mein   keet  "
        assert clean_text(padded) == "gehun mein keet"

    def test_nan_handling(self):
        """A float NaN is treated like missing text."""
        not_a_number = float("nan")
        assert clean_text(not_a_number) == ""


# ---------------------------------------------------------------------------
# align_words_dp
# ---------------------------------------------------------------------------

class TestAlignWordsDP:
    """Checks the operation sequence returned by ``align_words_dp``.

    Every test reads ``op[0]`` of each returned operation as its tag
    ("match", "sub", "del" or "ins").
    """

    @staticmethod
    def _tags(ops):
        """Return the operation tag (``op[0]``) of every op, in order."""
        return [op[0] for op in ops]

    def test_identical(self):
        """Identical sequences align as one match per word."""
        ops = align_words_dp(["a", "b", "c"], ["a", "b", "c"])
        # Compare the full tag list rather than using all(): all() is
        # vacuously true for an empty result and would hide a broken aligner.
        assert self._tags(ops) == ["match", "match", "match"]

    def test_substitution(self):
        """A changed final word yields a match followed by a substitution."""
        ops = align_words_dp(["a", "b"], ["a", "x"])
        assert self._tags(ops) == ["match", "sub"]

    def test_deletion(self):
        """A word missing from the hypothesis shows up as a deletion."""
        tags = self._tags(align_words_dp(["a", "b", "c"], ["a", "c"]))
        assert "del" in tags
        assert tags.count("match") == 2

    def test_insertion(self):
        """An extra hypothesis word shows up as an insertion."""
        tags = self._tags(align_words_dp(["a", "c"], ["a", "b", "c"]))
        assert "ins" in tags
        assert tags.count("match") == 2

    def test_empty_ref(self):
        """With no reference words, every hypothesis word is an insertion."""
        ops = align_words_dp([], ["a", "b"])
        # Pin the count as well as the tag so an empty result cannot pass.
        assert self._tags(ops) == ["ins", "ins"]

    def test_empty_hyp(self):
        """With no hypothesis words, every reference word is a deletion."""
        ops = align_words_dp(["a", "b"], [])
        # Pin the count as well as the tag so an empty result cannot pass.
        assert self._tags(ops) == ["del", "del"]


# ---------------------------------------------------------------------------
# parse_word_weights
# ---------------------------------------------------------------------------

class TestParseWordWeights:
    """Expectations for ``parse_word_weights`` across its accepted inputs."""

    def test_json_string(self):
        """A JSON-encoded list of [word, weight] pairs becomes a float dict."""
        encoded = json.dumps([["gehun", 4], ["keet", 3]])
        expected = {"gehun": 4.0, "keet": 3.0}
        assert parse_word_weights(encoded) == expected

    def test_empty(self):
        """The empty string and ``None`` both give an empty mapping."""
        for blank in ("", None):
            assert parse_word_weights(blank) == {}

    def test_invalid_json(self):
        """Text that is not JSON gives an empty mapping."""
        parsed = parse_word_weights("not json")
        assert parsed == {}

    def test_list_input(self):
        """An already-decoded list of pairs is accepted directly."""
        pairs = [["a", 2], ["b", 3]]
        expected = {"a": 2.0, "b": 3.0}
        assert parse_word_weights(pairs) == expected


# ---------------------------------------------------------------------------
# get_word_weight
# ---------------------------------------------------------------------------

class TestGetWordWeight:
    """Expectations for ``get_word_weight`` lookups and fallbacks."""

    def test_exact_match(self):
        """A word present in the mapping returns its stored weight."""
        table = {"gehun": 4.0}
        assert get_word_weight("gehun", table) == 4.0

    def test_case_insensitive(self):
        """A capitalised query still finds the lowercase entry."""
        table = {"gehun": 4.0}
        assert get_word_weight("Gehun", table) == 4.0

    def test_default(self):
        """A word absent from the mapping returns the supplied default."""
        table = {"gehun": 4.0}
        looked_up = get_word_weight("unknown", table, default_weight=1.0)
        assert looked_up == 1.0

    def test_empty(self):
        """An empty word with an empty mapping returns 1.0."""
        looked_up = get_word_weight("", {})
        assert looked_up == 1.0


# ---------------------------------------------------------------------------
# calculate_awwer
# ---------------------------------------------------------------------------

class TestCalculateAWWER:
    """Expectations for ``calculate_awwer`` with a small fixed weight table."""

    def setup_method(self):
        """Give every test a fresh weight table on ``self.weights``."""
        self.weights = dict(gehun=4.0, keet=4.0, mitti=3.0, gaon=1.0)

    def test_perfect_match(self):
        """Scoring a sentence against itself gives zero error."""
        sentence = "gehun mein keet laga hai"
        score = calculate_awwer(sentence, sentence, self.weights)
        assert score == 0.0

    def test_high_weight_error(self):
        """Substituting the weight-4 word "gehun" pushes AWWER above WER."""
        reference = "gehun mein keet laga hai"
        hypothesis = "gaon mein keet laga hai"
        weighted = calculate_awwer(reference, hypothesis, self.weights)
        unweighted = calculate_wer(reference, hypothesis)
        # Both scores must exist before the comparison is meaningful.
        assert weighted is not None
        assert unweighted is not None
        assert weighted > unweighted

    def test_none_on_empty_ref(self):
        """An empty or ``None`` reference gives ``None``."""
        for missing_ref in ("", None):
            assert calculate_awwer(missing_ref, "something", self.weights) is None

    def test_all_deletions(self):
        """An empty hypothesis deletes every reference word, so AWWER is 1.0."""
        score = calculate_awwer("gehun keet", "", self.weights)
        # Error weight equals total reference weight when everything is deleted.
        assert score == 1.0


# ---------------------------------------------------------------------------
# calculate_awwer_components
# ---------------------------------------------------------------------------

class TestCalculateAWWERComponents:
    """Expectations for the breakdown from ``calculate_awwer_components``."""

    def test_breakdown(self):
        """One substituted high-weight word is counted and reported."""
        result = calculate_awwer_components(
            "gehun mein keet",
            "gaon mein keet",
            {"gehun": 4.0, "keet": 4.0},
        )

        expected_counts = {
            "n_substitutions": 1,
            "n_deletions": 0,
            "n_insertions": 0,
        }
        for key, count in expected_counts.items():
            assert result[key] == count

        # Only "gehun" was substituted; "keet" matched.
        flagged = result["high_weight_errors"]
        assert len(flagged) == 1
        assert flagged[0]["ref_word"] == "gehun"


# ---------------------------------------------------------------------------
# calculate_awwer_from_string
# ---------------------------------------------------------------------------

class TestCalculateAWWERFromString:
    """Expectations for ``calculate_awwer_from_string`` with JSON weights."""

    def test_json_weights(self):
        """A sentence scored against itself with JSON weights gives 0.0."""
        sentence = "gehun mein keet"
        encoded_weights = json.dumps([["gehun", 4], ["keet", 4]])
        score = calculate_awwer_from_string(sentence, sentence, encoded_weights)
        assert score == 0.0


# ---------------------------------------------------------------------------
# calculate_wer / calculate_cer / calculate_mer
# ---------------------------------------------------------------------------

class TestStandardMetrics:
    """Expectations for ``calculate_wer``, ``calculate_cer``, ``calculate_mer``."""

    def test_wer_perfect(self):
        """Identical text gives a WER of zero."""
        text = "hello world"
        assert calculate_wer(text, "hello world") == 0.0

    def test_wer_all_wrong(self):
        """Three words all substituted gives a WER of one."""
        score = calculate_wer("a b c", "x y z")
        assert score == 1.0

    def test_wer_none_on_empty_ref(self):
        """An empty reference gives ``None``."""
        score = calculate_wer("", "hello")
        assert score is None

    def test_wer_empty_hyp(self):
        """An empty hypothesis against a two-word reference gives one."""
        score = calculate_wer("hello world", "")
        assert score == 1.0

    def test_cer_perfect(self):
        """Identical text gives a CER of zero."""
        score = calculate_cer("hello", "hello")
        assert score == 0.0

    def test_cer_nonzero(self):
        """A single changed character gives a positive CER."""
        score = calculate_cer("abc", "axc")
        assert score is not None
        assert score > 0

    def test_mer_perfect(self):
        """Identical text gives a MER of zero."""
        score = calculate_mer("hello world", "hello world")
        assert score == 0.0

    def test_mer_bounds(self):
        """MER for fully mismatched text stays inside [0, 1]."""
        score = calculate_mer("a b c", "x y z")
        assert score is not None
        assert 0.0 <= score <= 1.0

    def test_nan_handling(self):
        """A NaN reference gives ``None`` from every metric."""
        for metric in (calculate_wer, calculate_cer, calculate_mer):
            assert metric(float("nan"), "hello") is None