# SinCode — core/scorer.py
# Commit 9906dbd (Kalana): Refactor to core/ package: softmax MLM
# normalization, ambiguity handling, context-aware English detection
# (37/40 = 92.5%)
"""
Data-driven candidate scorer combining MLM, fidelity, and rank signals.
"""
import math
from dataclasses import dataclass, field
from typing import List
from core.constants import (
W_MLM, W_FIDELITY, W_RANK,
FIDELITY_SCALE, DICT_FIDELITY_DAMP,
SINHALA_VIRAMA, ZWJ,
)
@dataclass
class ScoredCandidate:
"""Holds a candidate word and its scoring breakdown."""
text: str
mlm_score: float = 0.0
fidelity_score: float = 0.0
rank_score: float = 0.0
combined_score: float = 0.0
is_english: bool = False
@dataclass
class WordDiagnostic:
"""Structured per-word diagnostics for evaluation and error analysis."""
step_index: int
input_word: str
rule_output: str
selected_candidate: str
beam_score: float
candidate_breakdown: List[ScoredCandidate]
class CandidateScorer:
    """
    Data-driven replacement for the old hardcoded penalty table.

    Ranks candidates by mixing three probabilistic signals:

    1. **MLM Score** (weight α = 0.55) — contextual fit from the
       XLM-RoBERTa masked language model.
    2. **Source-Aware Fidelity** (weight β = 0.45) —
       English candidates matching the input → 0.0 (user intent);
       dictionary candidates → damped Levenshtein to the rule output;
       rule-only outputs → penalised by virama/skeleton density;
       anything else → full Levenshtein distance to the rule output.
    3. **Rank Prior** (weight γ = 0.0, disabled) — the dictionary rank
       prior is off because dictionary entries are unordered.
    """

    def __init__(
        self,
        w_mlm: float = W_MLM,
        w_fidelity: float = W_FIDELITY,
        w_rank: float = W_RANK,
        fidelity_scale: float = FIDELITY_SCALE,
    ):
        # Signal weights plus the scale applied to fidelity penalties.
        self.w_mlm = w_mlm
        self.w_fidelity = w_fidelity
        self.w_rank = w_rank
        self.fidelity_scale = fidelity_scale

    # ── Levenshtein distance (pure-Python, no dependencies) ──────────
    @staticmethod
    def levenshtein(s1: str, s2: str) -> int:
        """Compute the Levenshtein edit distance between two strings."""
        # Trivial cases: one side empty → distance is the other's length.
        if not s1 or not s2:
            return len(s1) or len(s2)
        # Two-row dynamic programme over the edit matrix.
        row = list(range(len(s2) + 1))
        for i, a in enumerate(s1, start=1):
            nxt = [i]
            for j, b in enumerate(s2, start=1):
                nxt.append(min(
                    row[j] + 1,              # deletion
                    nxt[j - 1] + 1,          # insertion
                    row[j - 1] + (a != b),   # substitution (0 cost on match)
                ))
            row = nxt
        return row[-1]

    # ── Scoring components ───────────────────────────────────────────
    def compute_fidelity(
        self, candidate: str, rule_output: str,
        original_input: str = "", is_from_dict: bool = False,
        is_ambiguous: bool = False,
    ) -> float:
        """
        Source-aware transliteration fidelity.

        - **English matching input** → 0.0 (user-intent preservation).
        - **Dict + matches rule output** → strong bonus (+2.0),
          reduced to +0.5 when *is_ambiguous* (many dict candidates
          with different meanings → let MLM context decide).
        - **Dict only** → decaying bonus (1.0 down to 0.0 with distance).
        - **Rule-only outputs not in dictionary** → penalised by
          consonant-skeleton density (high virama ratio = malformed).
        - **Other** → full Levenshtein distance to rule output.
        """
        # 1. Preserve user intent: candidate equals the original input.
        if original_input and candidate.lower() == original_input.lower():
            return 0.0

        # 2. Dictionary-validated candidates.
        if is_from_dict:
            if candidate == rule_output:
                # Exact agreement with the rule output: strong bonus,
                # softened when several dictionary senses compete.
                return 0.5 if is_ambiguous else 2.0
            longest = max(len(candidate), len(rule_output), 1)
            distance = self.levenshtein(candidate, rule_output) / longest
            return max(0.0, 1.0 - distance * DICT_FIDELITY_DAMP)

        # 3. Rule-only output (not validated by dictionary): penalise a
        #    dense consonant skeleton, i.e. viramas not followed by ZWJ.
        if candidate == rule_output:
            bare = 0
            for pos, ch in enumerate(candidate):
                if ch != SINHALA_VIRAMA:
                    continue
                follower = candidate[pos + 1] if pos + 1 < len(candidate) else None
                if follower != ZWJ:
                    bare += 1
            return -(bare / max(len(candidate), 1)) * self.fidelity_scale * 2

        # 4. English word that does not match the input — uncertain.
        if candidate.isascii():
            return -0.5

        # 5. Out-of-dictionary Sinhala: full normalised edit distance.
        longest = max(len(candidate), len(rule_output), 1)
        return -(self.levenshtein(candidate, rule_output) / longest) * self.fidelity_scale

    @staticmethod
    def compute_rank_prior(rank: int, total: int) -> float:
        """Log-decay rank prior. First candidate → 0.0; later ones decay."""
        return 0.0 if total <= 1 else math.log(1.0 / (rank + 1))

    # ── Combined score ───────────────────────────────────────────────
    def score(
        self,
        mlm_score: float,
        candidate: str,
        rule_output: str,
        rank: int,
        total_candidates: int,
        is_english: bool = False,
        original_input: str = "",
        is_from_dict: bool = False,
        is_ambiguous: bool = False,
    ) -> ScoredCandidate:
        """Return a :class:`ScoredCandidate` with full breakdown."""
        fidelity = self.compute_fidelity(
            candidate, rule_output, original_input, is_from_dict,
            is_ambiguous,
        )
        prior = self.compute_rank_prior(rank, total_candidates)
        weighted = (
            self.w_mlm * mlm_score
            + self.w_fidelity * fidelity
            + self.w_rank * prior
        )
        return ScoredCandidate(
            text=candidate,
            mlm_score=mlm_score,
            fidelity_score=fidelity,
            rank_score=prior,
            combined_score=weighted,
            is_english=is_english,
        )