File size: 2,919 Bytes
f6f45d5
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
"""
Inference helper — given a romanized word, return top-K Sinhala candidates
using beam search on the fine-tuned ByT5 model.

Usage:
    from seq2seq.infer import Transliterator
    t = Transliterator()
    print(t.candidates("videowe", k=5))
    # ['වීඩියොවේ', 'වීඩියොවී', 'වීඩියොව', ...]
"""

from __future__ import annotations
from pathlib import Path
from typing import Optional

import torch
from transformers import ByT5Tokenizer, T5ForConditionalGeneration

DEFAULT_MODEL_PATH = Path(__file__).parent / "byt5-singlish-sinhala" / "final"


class Transliterator:
    """Romanized ("Singlish") → Sinhala transliterator.

    Wraps a fine-tuned ByT5 seq2seq model and returns top-K candidate
    transliterations via beam search.
    """

    def __init__(self, model_path: str | Path = DEFAULT_MODEL_PATH, device: Optional[str] = None):
        """Load tokenizer and model and move the model to the target device.

        Args:
            model_path: Local checkpoint directory or HF Hub model ID.
            device: Torch device string ("cuda"/"cpu"); auto-detected when None.
        """
        # Keep as string — Path() would convert '/' to '\' on Windows, breaking HF Hub IDs
        model_path = str(model_path)
        self.device = device or ("cuda" if torch.cuda.is_available() else "cpu")
        self.tokenizer = ByT5Tokenizer.from_pretrained(model_path)
        self.model = T5ForConditionalGeneration.from_pretrained(model_path)
        self.model.to(self.device)
        self.model.eval()  # inference only — disable dropout etc.

    def candidates(self, word: str, k: int = 5) -> list[str]:
        """Return top-k Sinhala transliteration candidates for a single word."""
        return self.batch_candidates([word], k=k)[0]

    def batch_candidates(self, words: list[str], k: int = 5) -> list[list[str]]:
        """
        Return top-k Sinhala candidates for each word in a single forward pass.
        Much faster than calling candidates() per word on a long sentence.

        Args:
            words: Romanized input words (lowercased internally).
            k: Number of candidates to return per word; must be >= 1.

        Returns:
            One list of deduplicated candidate strings per input word
            (may contain fewer than k entries after deduplication).

        Raises:
            ValueError: If k < 1.
        """
        # Guard edge cases up front: generate() raises opaque errors on an
        # empty batch or num_return_sequences < 1.
        if k < 1:
            raise ValueError(f"k must be >= 1, got {k}")
        if not words:
            return []

        lowered = [w.lower() for w in words]
        inputs = self.tokenizer(
            lowered,
            return_tensors="pt",
            padding=True,
            truncation=True,
            max_length=64,
        ).to(self.device)

        n = len(words)
        with torch.no_grad():
            outputs = self.model.generate(
                **inputs,
                # num_return_sequences must not exceed num_beams.
                num_beams=max(k, 5),
                num_return_sequences=k,
                max_new_tokens=64,
                early_stopping=True,
            )

        # outputs shape: (n * k, seq_len) — k sequences per input, grouped
        results: list[list[str]] = []
        for i in range(n):
            seen: set[str] = set()
            cands: list[str] = []
            for seq in outputs[i * k : (i + 1) * k]:
                text = self.tokenizer.decode(seq, skip_special_tokens=True).strip()
                # Beams often differ only in special tokens; dedupe after decode.
                if text and text not in seen:
                    seen.add(text)
                    cands.append(text)
            results.append(cands)

        return results


if __name__ == "__main__":
    import sys

    # Transliterate words given on the command line; default to a demo word.
    cli_words = sys.argv[1:] or ["wadi"]
    transliterator = Transliterator()
    for word in cli_words:
        print(f"Candidates for '{word}':")
        for c in transliterator.candidates(word):
            print(f"  {c}")
        print()