| """ |
| Data-driven candidate scorer combining MLM, fidelity, and rank signals. |
| """ |
|
|
| import math |
| from dataclasses import dataclass, field |
| from typing import List |
|
|
| from core.constants import ( |
| W_MLM, W_FIDELITY, W_RANK, |
| FIDELITY_SCALE, DICT_FIDELITY_DAMP, |
| SINHALA_VIRAMA, ZWJ, |
| ) |
|
|
|
|
@dataclass
class ScoredCandidate:
    """One candidate word plus the individual signals that produced its score.

    ``CandidateScorer.score`` builds these records. Every numeric field
    defaults to ``0.0`` and ``is_english`` defaults to ``False``, so only
    ``text`` is required.
    """

    # The candidate string being scored.
    text: str
    # Raw masked-language-model score, stored unweighted.
    mlm_score: float = 0.0
    # Result of ``CandidateScorer.compute_fidelity``, stored unweighted.
    fidelity_score: float = 0.0
    # Result of ``CandidateScorer.compute_rank_prior``, stored unweighted.
    rank_score: float = 0.0
    # Weighted sum of the three signals above.
    combined_score: float = 0.0
    # Caller-supplied flag; carried through unchanged by the scorer.
    is_english: bool = False
|
|
|
|
@dataclass
class WordDiagnostic:
    """Structured per-word diagnostics for evaluation and error analysis.

    All six fields are required. Nothing in this module populates the
    record, so the field notes below are read off the names and should be
    confirmed against the code that builds it.
    """

    # Presumably the position of the word in the decoded sequence - confirm.
    step_index: int
    # Presumably the word as the user typed it - confirm.
    input_word: str
    # Presumably the rule-based transliteration of ``input_word`` - confirm.
    rule_output: str
    # Presumably the text of the candidate finally chosen - confirm.
    selected_candidate: str
    # Presumably the beam-search score at this step - confirm.
    beam_score: float
    # Scoring breakdown for the candidates considered for this word.
    # Forward-reference string keeps the class independent of definition order.
    candidate_breakdown: List["ScoredCandidate"]
|
|
|
|
class CandidateScorer:
    """
    Data-driven replacement for the old hardcoded penalty table.

    Ranks candidates by a weighted sum of three signals::

        combined = w_mlm * mlm_score + w_fidelity * fidelity + w_rank * rank_prior

    1. **MLM Score** (weight α = ``w_mlm``, default ``W_MLM``)
       Contextual-fit score supplied by the caller. The original notes
       describe it as coming from an XLM-RoBERTa masked language model
       and give the weight as 0.55.

    2. **Source-Aware Fidelity** (weight β = ``w_fidelity``, default
       ``W_FIDELITY``; 0.45 in the original notes)
       Computed by :meth:`compute_fidelity`:
       English candidates matching the input → 0.0 (user intent).
       Dictionary candidates → damped Levenshtein to the rule output.
       Rule-only outputs → penalised by bare-virama density.
       Remaining ASCII candidates → flat -0.5.
       Other → negated, normalised Levenshtein distance to the rule output.

    3. **Rank Prior** (weight γ = ``w_rank``, default ``W_RANK``)
       Computed by :meth:`compute_rank_prior`. The original notes state
       the weight is 0.0 (disabled) because dictionary entries are
       unordered.

    NOTE(review): the numeric weights quoted above come from the original
    docstring; the authoritative values live in ``core.constants`` and
    are not visible from this module.
    """

    def __init__(
        self,
        w_mlm: float = W_MLM,
        w_fidelity: float = W_FIDELITY,
        w_rank: float = W_RANK,
        fidelity_scale: float = FIDELITY_SCALE,
    ):
        """Store the signal weights and the fidelity penalty scale.

        Args:
            w_mlm: Multiplier applied to the MLM score.
            w_fidelity: Multiplier applied to the fidelity score.
            w_rank: Multiplier applied to the rank prior.
            fidelity_scale: Scale of the negative fidelity penalties for
                candidates that do not come from the dictionary.
        """
        self.w_mlm = w_mlm
        self.w_fidelity = w_fidelity
        self.w_rank = w_rank
        self.fidelity_scale = fidelity_scale

    @staticmethod
    def levenshtein(s1: str, s2: str) -> int:
        """Compute the Levenshtein edit distance between two strings.

        The distance is counted over ``str`` elements (Unicode code
        points), so a base letter and each combining mark attached to it
        are separate units.

        Args:
            s1: First string.
            s2: Second string.

        Returns:
            The minimum number of single-element insertions, deletions
            and substitutions that turn ``s1`` into ``s2``.
        """
        if s1 == s2:
            return 0
        if not s1:
            return len(s2)
        if not s2:
            return len(s1)

        # The distance is symmetric, so keep the shorter string on the
        # inner axis; the row buffers then have min(len) + 1 entries.
        if len(s1) < len(s2):
            s1, s2 = s2, s1

        prev_row = list(range(len(s2) + 1))
        for i, ch1 in enumerate(s1, start=1):
            curr_row = [i]
            for j, ch2 in enumerate(s2, start=1):
                cost = 0 if ch1 == ch2 else 1
                curr_row.append(min(
                    prev_row[j] + 1,         # delete ch1
                    curr_row[j - 1] + 1,     # insert ch2
                    prev_row[j - 1] + cost,  # substitute / match
                ))
            prev_row = curr_row

        return prev_row[-1]

    def _normalized_distance(self, candidate: str, rule_output: str) -> float:
        """Return the edit distance divided by the longer string's length."""
        # The trailing 1 guards the division when both strings are empty.
        longest = max(len(candidate), len(rule_output), 1)
        return self.levenshtein(candidate, rule_output) / longest

    @staticmethod
    def _bare_virama_density(text: str) -> float:
        """Return the fraction of *text* made up of bare viramas.

        A virama counts as bare when it is the last character or when
        the character after it is not ``ZWJ``.
        """
        bare_virama = sum(
            1 for i, ch in enumerate(text)
            if ch == SINHALA_VIRAMA
            and (i + 1 >= len(text) or text[i + 1] != ZWJ)
        )
        return bare_virama / max(len(text), 1)

    def compute_fidelity(
        self, candidate: str, rule_output: str,
        original_input: str = "", is_from_dict: bool = False,
        is_ambiguous: bool = False,
    ) -> float:
        """
        Source-aware transliteration fidelity.

        The first matching case wins:

        - **Matches the input** (case-insensitive, non-empty
          ``original_input``) → 0.0 (user-intent preservation).
        - **Dict + matches rule output** → strong bonus (+2.0),
          reduced to +0.5 when *is_ambiguous* (many dict candidates
          with different meanings, so let MLM context decide).
        - **Dict only** → ``1.0 - normalised distance *
          DICT_FIDELITY_DAMP``, clamped so it never drops below 0.0.
        - **Rule output not from the dictionary** → penalty of
          ``-density * fidelity_scale * 2`` where *density* is the share
          of bare viramas (a high ratio suggests a malformed
          consonant skeleton).
        - **Any other pure-ASCII candidate** → flat -0.5. An empty
          candidate also lands here because ``"".isascii()`` is true.
        - **Other** → ``-normalised distance * fidelity_scale``, where
          the distance is Levenshtein to the rule output divided by the
          longer string's length.

        Args:
            candidate: Candidate text to score.
            rule_output: Rule-based transliteration to compare against.
            original_input: The word as typed; empty disables the
                input-match shortcut.
            is_from_dict: Whether the candidate came from the dictionary.
            is_ambiguous: Whether the word has many dictionary
                candidates; only affects the exact-match dictionary bonus.

        Returns:
            The fidelity score; higher is better.
        """
        if original_input and candidate.lower() == original_input.lower():
            return 0.0

        matches_rule = candidate == rule_output

        if is_from_dict:
            if matches_rule:
                return 0.5 if is_ambiguous else 2.0
            norm_dist = self._normalized_distance(candidate, rule_output)
            return max(0.0, 1.0 - norm_dist * DICT_FIDELITY_DAMP)

        if matches_rule:
            density = self._bare_virama_density(candidate)
            return -density * self.fidelity_scale * 2

        if candidate.isascii():
            return -0.5

        norm_dist = self._normalized_distance(candidate, rule_output)
        return -norm_dist * self.fidelity_scale

    @staticmethod
    def compute_rank_prior(rank: int, total: int) -> float:
        """Log-decay rank prior. First candidate → 0.0; later ones decay.

        Args:
            rank: Zero-based position of the candidate; ``rank == 0``
                yields 0.0 and larger ranks yield ``-log(rank + 1)``.
            total: Number of candidates. With one candidate or fewer the
                prior is 0.0 regardless of ``rank``.

        Returns:
            A value that is at most 0.0 for non-negative ``rank``.

        Raises:
            ZeroDivisionError: If ``rank == -1`` and ``total > 1``.
            ValueError: If ``rank < -1`` and ``total > 1``.
        """
        if total <= 1:
            return 0.0
        # Expression kept exactly as written so scores stay bit-identical.
        return math.log(1.0 / (rank + 1))

    def score(
        self,
        mlm_score: float,
        candidate: str,
        rule_output: str,
        rank: int,
        total_candidates: int,
        is_english: bool = False,
        original_input: str = "",
        is_from_dict: bool = False,
        is_ambiguous: bool = False,
    ) -> ScoredCandidate:
        """Return a :class:`ScoredCandidate` with full breakdown.

        Args:
            mlm_score: Contextual score for the candidate, used as given.
            candidate: Candidate text.
            rule_output: Rule-based transliteration of the input word.
            rank: Zero-based position passed to :meth:`compute_rank_prior`.
            total_candidates: Candidate count passed to
                :meth:`compute_rank_prior`.
            is_english: Copied onto the result; it does not influence
                any of the scores computed here.
            original_input: Forwarded to :meth:`compute_fidelity`.
            is_from_dict: Forwarded to :meth:`compute_fidelity`.
            is_ambiguous: Forwarded to :meth:`compute_fidelity`.

        Returns:
            A record holding the three unweighted signals and their
            weighted sum in ``combined_score``.
        """
        fidelity = self.compute_fidelity(
            candidate, rule_output, original_input, is_from_dict,
            is_ambiguous,
        )
        rank_prior = self.compute_rank_prior(rank, total_candidates)

        combined = (
            self.w_mlm * mlm_score
            + self.w_fidelity * fidelity
            + self.w_rank * rank_prior
        )

        return ScoredCandidate(
            text=candidate,
            mlm_score=mlm_score,
            fidelity_score=fidelity,
            rank_score=rank_prior,
            combined_score=combined,
            is_english=is_english,
        )
|
|