File size: 4,301 Bytes
f6f45d5
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
"""
mBart50-based Sentence Transliterator for SinCode v3.

Full-sentence Singlish → Sinhala transliteration.
Unlike the ByT5 word-by-word pipeline, mBart50 operates on the whole input
sentence and produces fully Sinhalized output — no English words are retained.

Use-case: "mn heta business ekak start karanawa"
       → "මන් හෙට ව්‍යාපාරයක් පටන් ගන්නවා"
"""

from __future__ import annotations

import json
import logging
import re
from pathlib import Path
from typing import Optional

import torch
from transformers import MBart50Tokenizer, MBartForConditionalGeneration

from core.constants import DEFAULT_MBART_MODEL

# Module-level logger, named after this module via __name__.
logger = logging.getLogger(__name__)

# ── Fix-map (ZWJ / Virama composition) ───────────────────────────────────────

# JSON file holding the composition fix-map; resolved as a sibling of this
# module's own file.
_FIX_MAP_PATH = Path(__file__).with_name("compose_fix_map.json")

# Parsed fix-map contents. Remains None until _load_fix_map() has read the
# file once; later calls reuse the stored dict.
_fix_map_cache: dict[str, str] | None = None


def _load_fix_map() -> dict[str, str]:
    """Return the fix-map, parsing the JSON file only on the first call.

    The parsed result is stored in the module-level ``_fix_map_cache`` and
    that same object is returned on every later call.

    Returns:
        The object decoded from the UTF-8 JSON file at ``_FIX_MAP_PATH``.
    """
    global _fix_map_cache
    # Fast path: the file has already been read during this process.
    if _fix_map_cache is not None:
        return _fix_map_cache
    raw_json = _FIX_MAP_PATH.read_text(encoding="utf-8")
    _fix_map_cache = json.loads(raw_json)
    return _fix_map_cache


# ── Input cleaning ────────────────────────────────────────────────────────────

# Character class matching any single code point from the script ranges
# listed below. _clean() drops every word in which this pattern matches.
_UNSUPPORTED_SCRIPT = re.compile(
    "[{}]".format(
        "".join(
            (
                r"\u0B80-\u0BFF",  # Tamil
                r"\u0900-\u097F",  # Devanagari
                r"\u4E00-\u9FFF",  # CJK Unified Ideographs
                r"\u3040-\u309F",  # Hiragana
                r"\u30A0-\u30FF",  # Katakana
                r"\u0E00-\u0E7F",  # Thai
                r"\u0600-\u06FF",  # Arabic
                r"\u0590-\u05FF",  # Hebrew
                r"\uAC00-\uD7AF",  # Hangul
            )
        )
    )
)


def _clean(text: str) -> str | None:
    """Drop whitespace-separated words that contain unsupported-script characters.

    Args:
        text: Raw input string.

    Returns:
        The surviving words joined by single spaces, or None when no word
        survives (this includes empty and whitespace-only input).
    """
    kept: list[str] = []
    # split() without arguments already ignores leading/trailing whitespace.
    for token in text.split():
        if _UNSUPPORTED_SCRIPT.search(token) is None:
            kept.append(token)
    if not kept:
        return None
    return " ".join(kept)


def _apply_fixes(text: str) -> str:
    """Run every fix-map substitution over *text* and return the result.

    Each key of the fix-map is used as a regular-expression pattern and its
    value as the replacement. Substitutions are applied one after another in
    the map's iteration order, each on the output of the previous one.
    """
    fix_map = _load_fix_map()
    fixed = text
    for regex in fix_map:
        fixed = re.compile(regex).sub(fix_map[regex], fixed)
    return fixed


# ── Transliterator ────────────────────────────────────────────────────────────

class SentenceTransliterator:
    """
    Full-sentence Singlish → Sinhala transliterator (mBart50).

    The tokenizer and model are loaded with ``from_pretrained`` when an
    instance is created; the model is then moved to the selected device and
    switched to eval mode. The tokenizer's source language is configured once
    in ``__init__``, so ``transliterate`` itself assigns no attributes.

    NOTE(review): whether the underlying tokenizer/model objects are safe to
    share between threads is not established by this class — confirm before
    calling ``transliterate`` concurrently.
    """

    # mBart50 language code used both as the tokenizer source language and as
    # the forced first generated token.
    _LANG_CODE = "si_LK"

    # Class-level fallback for the tokenizer truncation limit; __init__
    # overrides it per instance.
    max_length: int = 128

    def __init__(
        self,
        model_name: str = DEFAULT_MBART_MODEL,
        device: Optional[str] = None,
        *,
        max_length: int = 128,
    ):
        """
        Load the tokenizer and model and prepare them for inference.

        Args:
            model_name: Identifier passed to ``from_pretrained`` for both the
                tokenizer and the model.
            device: Torch device string. When omitted, "cuda" is used if
                ``torch.cuda.is_available()`` is true and "cpu" otherwise.
            max_length: Maximum number of input tokens; longer inputs are
                truncated by the tokenizer.

        Raises:
            ValueError: If ``max_length`` is smaller than 1.
        """
        if max_length < 1:
            raise ValueError("max_length must be a positive integer")

        self.device = device or ("cuda" if torch.cuda.is_available() else "cpu")
        self.max_length = max_length

        logger.info("Loading mBart50 transliterator: %s", model_name)
        self.tokenizer = MBart50Tokenizer.from_pretrained(model_name)
        # Set once here rather than on every call, so transliterate() does
        # not write to the tokenizer attribute each time it runs.
        self.tokenizer.src_lang = self._LANG_CODE
        self.model = MBartForConditionalGeneration.from_pretrained(model_name)
        self.model.to(self.device)
        self.model.eval()

    def transliterate(self, text: str) -> str:
        """
        Transliterate a Singlish sentence to fully-Sinhalized output.

        Args:
            text: Input Singlish sentence (Romanized Sinhala / English mix).

        Returns:
            Sinhala-script output with the composition fixes applied. Returns
            the original text unchanged if it is empty or if every word
            contains unsupported-script characters.
        """
        cleaned = _clean(text)
        if not cleaned:
            # Nothing left to feed the model; hand the input back untouched.
            return text

        inputs = self.tokenizer(
            cleaned,
            return_tensors="pt",
            padding=True,
            truncation=True,
            max_length=self.max_length,
        ).to(self.device)

        # Inference only: no gradients are needed.
        with torch.no_grad():
            tokens = self.model.generate(
                **inputs,
                forced_bos_token_id=self.tokenizer.lang_code_to_id[self._LANG_CODE],
            )

        # A single sentence was encoded, so only the first sequence is decoded.
        output = self.tokenizer.decode(tokens[0], skip_special_tokens=True)
        return _apply_fixes(output)