File size: 6,403 Bytes
9906dbd
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
"""
Data-driven candidate scorer combining MLM, fidelity, and rank signals.
"""

import math
from dataclasses import dataclass, field
from typing import List

from core.constants import (
    W_MLM, W_FIDELITY, W_RANK,
    FIDELITY_SCALE, DICT_FIDELITY_DAMP,
    SINHALA_VIRAMA, ZWJ,
)


@dataclass
class ScoredCandidate:
    """Holds a candidate word and its scoring breakdown."""
    text: str  # the candidate word itself
    mlm_score: float = 0.0  # contextual-fit signal from the masked language model
    fidelity_score: float = 0.0  # source-aware transliteration fidelity signal
    rank_score: float = 0.0  # dictionary rank prior (weighted by W_RANK, currently 0)
    combined_score: float = 0.0  # weighted sum of the three signals above
    is_english: bool = False  # caller-supplied flag marking an English candidate


@dataclass
class WordDiagnostic:
    """Structured per-word diagnostics for evaluation and error analysis."""
    step_index: int  # position of this word in the processed sequence — presumably 0-based; confirm against producer
    input_word: str  # the raw input word as received
    rule_output: str  # transliteration produced by the rule-based stage
    selected_candidate: str  # the candidate ultimately chosen for this word
    beam_score: float  # score of the selected candidate — NOTE(review): assumed to be the beam/combined score; verify against caller
    candidate_breakdown: List[ScoredCandidate]  # full scoring breakdown for every candidate considered


class CandidateScorer:
    """
    Probabilistic candidate ranker replacing the legacy hardcoded
    penalty table.

    Three signals are blended into a single score:

    1. **MLM Score** (weight α = 0.55)
       How well the candidate fits its sentence context, judged by the
       XLM-RoBERTa masked language model.

    2. **Source-Aware Fidelity** (weight β = 0.45)
       English candidates reproducing the input → 0.0 (user intent).
       Dictionary candidates → damped Levenshtein against the rule output.
       Rule-only outputs → penalised by virama/skeleton density.
       Anything else → full Levenshtein distance to the rule output.

    3. **Rank Prior** (weight γ = 0.0, disabled)
       Disabled because dictionary entries carry no meaningful order.
    """

    def __init__(
        self,
        w_mlm: float = W_MLM,
        w_fidelity: float = W_FIDELITY,
        w_rank: float = W_RANK,
        fidelity_scale: float = FIDELITY_SCALE,
    ):
        self.w_mlm = w_mlm
        self.w_fidelity = w_fidelity
        self.w_rank = w_rank
        self.fidelity_scale = fidelity_scale

    # ── Levenshtein distance (pure-Python, no dependencies) ──────────

    @staticmethod
    def levenshtein(s1: str, s2: str) -> int:
        """Return the Levenshtein edit distance between *s1* and *s2*."""
        # Trivial cases: distance to an empty string is the other's length.
        if not s1:
            return len(s2)
        if not s2:
            return len(s1)

        width = len(s2)
        previous = list(range(width + 1))  # DP row for the empty prefix of s1

        # Two-row dynamic programme: keep only the previous and current rows.
        for row, ch_a in enumerate(s1, start=1):
            current = [row]
            for col, ch_b in enumerate(s2, start=1):
                current.append(min(
                    previous[col] + 1,                    # deletion
                    current[col - 1] + 1,                 # insertion
                    previous[col - 1] + (ch_a != ch_b),   # substitution
                ))
            previous = current

        return previous[width]

    # ── Scoring components ───────────────────────────────────────────

    def compute_fidelity(
        self, candidate: str, rule_output: str,
        original_input: str = "", is_from_dict: bool = False,
        is_ambiguous: bool = False,
    ) -> float:
        """
        Source-aware transliteration fidelity.

        - **English matching input** → 0.0  (user-intent preservation).
        - **Dict + matches rule output** → strong bonus (+2.0), reduced
          to +0.5 when *is_ambiguous* (many dictionary candidates with
          different meanings → let MLM context decide).
        - **Dict only** → decaying bonus (1.0 down to 0.0 with distance).
        - **Rule-only outputs not in dictionary** → penalised by
          consonant-skeleton density (high virama ratio = malformed).
        - **Other** → full Levenshtein distance to rule output.
        """
        # 1. English candidate that reproduces the user's original word.
        if original_input and candidate.lower() == original_input.lower():
            return 0.0

        # 2. Dictionary-validated candidate.
        if is_from_dict:
            if candidate == rule_output:
                # Exact match with the rule output: strong bonus unless
                # the dictionary entry is ambiguous.
                return 0.5 if is_ambiguous else 2.0
            longest = max(len(candidate), len(rule_output), 1)
            distance = self.levenshtein(candidate, rule_output) / longest
            return max(0.0, 1.0 - distance * DICT_FIDELITY_DAMP)

        # 3. Rule-only output, not validated by the dictionary: penalise
        #    by the density of "bare" viramas (not followed by a ZWJ),
        #    which indicates a malformed consonant skeleton.
        if candidate == rule_output:
            length = len(candidate)
            bare = 0
            for idx, ch in enumerate(candidate):
                if ch == SINHALA_VIRAMA and (idx + 1 >= length or candidate[idx + 1] != ZWJ):
                    bare += 1
            density = bare / max(length, 1)
            return -density * self.fidelity_scale * 2

        # 4. English word that does not match the input — uncertain.
        if candidate.isascii():
            return -0.5

        # 5. Sinhala candidate not from the dictionary — distance penalty.
        longest = max(len(candidate), len(rule_output), 1)
        distance = self.levenshtein(candidate, rule_output) / longest
        return -distance * self.fidelity_scale

    @staticmethod
    def compute_rank_prior(rank: int, total: int) -> float:
        """Log-decay rank prior. First candidate → 0.0; later ones decay."""
        if total <= 1:
            return 0.0
        return math.log(1.0 / (rank + 1))

    # ── Combined score ───────────────────────────────────────────────

    def score(
        self,
        mlm_score: float,
        candidate: str,
        rule_output: str,
        rank: int,
        total_candidates: int,
        is_english: bool = False,
        original_input: str = "",
        is_from_dict: bool = False,
        is_ambiguous: bool = False,
    ) -> ScoredCandidate:
        """Return a :class:`ScoredCandidate` with full breakdown."""
        fidelity = self.compute_fidelity(
            candidate, rule_output, original_input, is_from_dict,
            is_ambiguous,
        )
        rank_prior = self.compute_rank_prior(rank, total_candidates)

        # Weighted blend of the three signals (γ is 0.0 by default, so
        # the rank prior contributes nothing unless re-enabled).
        weighted_mlm = self.w_mlm * mlm_score
        weighted_fid = self.w_fidelity * fidelity
        weighted_rank = self.w_rank * rank_prior

        return ScoredCandidate(
            text=candidate,
            mlm_score=mlm_score,
            fidelity_score=fidelity,
            rank_score=rank_prior,
            combined_score=weighted_mlm + weighted_fid + weighted_rank,
            is_english=is_english,
        )