# Source: sinhala-tts/scripts/sinhala_tokenizer.py (author: outlawmold)
# Commit be6f20b (verified): "Add Sinhala grapheme tokenizer for TTS"
#!/usr/bin/env python3
"""
=============================================================
Sinhala Grapheme Tokenizer for TTS
=============================================================
Tokenizes Sinhala text into orthographic syllables (aksharas)
for FastPitch character-level TTS training.
Sinhala is a Brahmic abugida (U+0D80–U+0DFF):
- Consonants carry an inherent /a/ vowel
- Al-lakuna (virama) ් suppresses the inherent vowel
- ZWJ (U+200D) forms conjunct consonants: ක් + ZWJ + ෂ → ක්‍ෂ
- Vowel diacritics modify the inherent vowel
An akshara (grapheme cluster) is the minimal unit:
- Independent vowel: අ, ආ, ඉ, ඊ, ...
- Consonant + optional vowel sign: ක, කා, කි, ...
- Conjunct: consonant + virama + ZWJ + consonant + vowel sign
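Why aksharas (not phonemes):
IndicTTS (arXiv:2211.09536) showed that character-level input works
well for all 13 Indic languages it evaluated, so no phoneme dictionary
is needed. Note: tokenize() yields aksharas, while encode() assigns one
ID per Unicode codepoint.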
Usage:
from sinhala_tokenizer import SinhalaTokenizer
tok = SinhalaTokenizer()
tokens = tok.tokenize("ශ්‍රී ලංකාව")
ids = tok.encode("ශ්‍රී ලංකාව")
text = tok.decode(ids)
=============================================================
"""
import json
import re
import unicodedata
from pathlib import Path
from typing import Dict, List, Optional, Tuple
# ============================================================
# Unicode Constants for Sinhala (U+0D80–U+0DFF)
# ============================================================
# NOTE: each group below is a run of adjacent string literals inside
# parentheses (no commas), so Python concatenates it into ONE str that the
# rest of the module iterates character by character. The order matters:
# SinhalaTokenizer._build_default_vocab assigns IDs in this order.

# Independent vowels (can stand alone)
SINHALA_VOWELS = (
    '\u0D85'  # අ (a)
    '\u0D86'  # ආ (aa)
    '\u0D87'  # ඇ (ae)
    '\u0D88'  # ඈ (aae)
    '\u0D89'  # ඉ (i)
    '\u0D8A'  # ඊ (ii)
    '\u0D8B'  # උ (u)
    '\u0D8C'  # ඌ (uu)
    '\u0D8D'  # ඍ (ri)
    '\u0D8E'  # ඎ (rii) - rare
    '\u0D8F'  # ඏ (li) - rare
    '\u0D90'  # ඐ (lii) - rare
    '\u0D91'  # එ (e)
    '\u0D92'  # ඒ (ee)
    '\u0D93'  # ඓ (ai)
    '\u0D94'  # ඔ (o)
    '\u0D95'  # ඕ (oo)
    '\u0D96'  # ඖ (au)
)

# Consonants (U+0DB2, U+0DBC, U+0DBE, U+0DBF are unassigned in Unicode)
SINHALA_CONSONANTS = (
    '\u0D9A'  # ක (ka)
    '\u0D9B'  # ඛ (kha)
    '\u0D9C'  # ග (ga)
    '\u0D9D'  # ඝ (gha)
    '\u0D9E'  # ඞ (nga)
    '\u0D9F'  # ඟ (nnga) - sanyaka (prenasalized)
    '\u0DA0'  # ච (cha)
    '\u0DA1'  # ඡ (chha)
    '\u0DA2'  # ජ (ja)
    '\u0DA3'  # ඣ (jha)
    '\u0DA4'  # ඤ (nya)
    '\u0DA5'  # ඥ (jnya)
    '\u0DA6'  # ඦ (nyja) - sanyaka
    '\u0DA7'  # ට (tta)
    '\u0DA8'  # ඨ (ttha)
    '\u0DA9'  # ඩ (dda)
    '\u0DAA'  # ඪ (ddha)
    '\u0DAB'  # ණ (nna)
    '\u0DAC'  # ඬ (ndda) - sanyaka
    '\u0DAD'  # ත (ta)
    '\u0DAE'  # ථ (tha)
    '\u0DAF'  # ද (da)
    '\u0DB0'  # ධ (dha)
    '\u0DB1'  # න (na)
    '\u0DB3'  # ඳ (nda) - sanyaka
    '\u0DB4'  # ප (pa)
    '\u0DB5'  # ඵ (pha)
    '\u0DB6'  # බ (ba)
    '\u0DB7'  # භ (bha)
    '\u0DB8'  # ම (ma)
    '\u0DB9'  # ඹ (mba) - sanyaka
    '\u0DBA'  # ය (ya)
    '\u0DBB'  # ර (ra)
    '\u0DBD'  # ල (la)
    '\u0DC0'  # ව (va/wa)
    '\u0DC1'  # ශ (sha)
    '\u0DC2'  # ෂ (ssa)
    '\u0DC3'  # ස (sa)
    '\u0DC4'  # හ (ha)
    '\u0DC5'  # ළ (lla)
    '\u0DC6'  # ෆ (fa) - used for foreign words
)

# Virama (Al-Lakuna) — kills the inherent vowel
VIRAMA = '\u0DCA'  # ්

# Zero-Width Joiner — forms visual conjuncts
ZWJ = '\u200D'

# Dependent vowel signs (diacritics on consonants)
SINHALA_VOWEL_SIGNS = (
    '\u0DCF'  # ා (aa)
    '\u0DD0'  # ැ (ae)
    '\u0DD1'  # ෑ (aae)
    '\u0DD2'  # ි (i)
    '\u0DD3'  # ී (ii)
    '\u0DD4'  # ු (u)
    '\u0DD6'  # ූ (uu)
    '\u0DD8'  # ෘ (ri)
    '\u0DD9'  # ෙ (e) — pre-base
    '\u0DDA'  # ේ (ee)
    '\u0DDB'  # ෛ (ai)
    '\u0DDC'  # ො (o)
    '\u0DDD'  # ෝ (oo)
    '\u0DDE'  # ෞ (au)
    '\u0DDF'  # ෟ (li) — rare
)

# Anusvara and Visarga
ANUSVARA = '\u0D82'  # ං
VISARGA = '\u0D83'  # ඃ

# Sinhala (Lith) digits ෦-෯
SINHALA_DIGITS = ''.join(chr(c) for c in range(0x0DE6, 0x0DF0))

# Sets for fast per-character membership tests
VOWEL_SET = set(SINHALA_VOWELS)
CONSONANT_SET = set(SINHALA_CONSONANTS)
VOWEL_SIGN_SET = set(SINHALA_VOWEL_SIGNS)
MODIFIER_SET = VOWEL_SIGN_SET | {ANUSVARA, VISARGA, VIRAMA}

# ============================================================
# Regex pattern for Sinhala grapheme clusters (aksharas)
# ============================================================
# An akshara is:
#     Consonant (Virama ZWJ? Consonant)* (VowelSign | Virama ZWJ?)? (Anusvara|Visarga)?
#   | IndependentVowel (Anusvara|Visarga)?
#   | Anusvara | Visarga (standalone)
# The "Virama ZWJ?" ending covers dead (vowel-less) consonants such as the
# final ස් in දවස්. Without it, the trailing al-lakuna fell through to the
# catch-all branch and became a separate one-character token.
_C = f'[{"".join(SINHALA_CONSONANTS)}]'  # consonant
_V = f'[{"".join(SINHALA_VOWELS)}]'  # independent vowel
_VS = f'[{"".join(SINHALA_VOWEL_SIGNS)}]'  # vowel sign
_VIR = re.escape(VIRAMA)
_ZWJ = re.escape(ZWJ)
_ANU = re.escape(ANUSVARA)
_VIS = re.escape(VISARGA)
_MOD = f'[{_ANU}{_VIS}]'  # anusvara or visarga

# Conjunct consonant: virama + optional ZWJ + consonant
_CONJUNCT = f'{_VIR}{_ZWJ}?{_C}'

# Dead-consonant ending: virama + optional ZWJ with no consonant after it
_DEAD_END = f'{_VIR}{_ZWJ}?'

# Full akshara pattern. The capture-group numbering (7 groups) is the same
# as before; only non-capturing groups were added.
AKSHARA_PATTERN = re.compile(
    f'({_C}(?:{_CONJUNCT})*(?:{_VS}|{_DEAD_END})?{_MOD}?)'  # consonant cluster
    f'|({_V}{_MOD}?)'  # independent vowel
    f'|({_MOD})'  # standalone modifier
    f'|([0-9{SINHALA_DIGITS}]+)'  # number
    f'|([!?.,;:\'"\\-–—…])'  # punctuation
    f'|( +)'  # space(s)
    f'|(.)',  # anything else (1 char)
    re.UNICODE,
)
class SinhalaTokenizer:
    """
    Sinhala grapheme (akshara) tokenizer for TTS.

    Designed for Coqui-TTS FastPitch training.
    ``tokenize`` maps text to grapheme clusters (aksharas); ``encode`` maps
    text to integer IDs with one ID per Unicode codepoint, and ``decode``
    maps IDs back to text.
    """

    # Special tokens (IDs 0-4 of the default vocabulary, in this order)
    PAD = "<PAD>"
    BOS = "<BOS>"
    EOS = "<EOS>"
    UNK = "<UNK>"
    BLANK = "<BLNK>"
    SPACE = " "

    # Punctuation tokens of the default vocabulary, in ID order. The '\\'
    # escape means a literal backslash is part of the vocabulary; it is kept
    # so default-vocabulary IDs stay stable for already-trained models.
    _VOCAB_PUNCTUATION = '!?.,;:\'"\\-'

    # Single-character rewrites applied by normalize(). Typographic quotes,
    # dashes and the ellipsis become the ASCII forms in the default
    # vocabulary (otherwise encode() would emit <UNK> for them). ';' and ':'
    # become ','. ZWNJ and parentheses are removed. ZWJ is kept because it
    # forms conjuncts.
    _NORMALIZE_TABLE = str.maketrans({
        '\u200C': None,   # ZERO WIDTH NON-JOINER
        '\u201C': '"',    # LEFT DOUBLE QUOTATION MARK
        '\u201D': '"',    # RIGHT DOUBLE QUOTATION MARK
        '\u2018': "'",    # LEFT SINGLE QUOTATION MARK
        '\u2019': "'",    # RIGHT SINGLE QUOTATION MARK
        '\u2013': '-',    # EN DASH
        '\u2014': '-',    # EM DASH
        '\u2026': '...',  # HORIZONTAL ELLIPSIS
        ';': ',',
        ':': ',',
        '(': None,
        ')': None,
    })

    def __init__(self, vocab_path: Optional[str] = None):
        """
        Initialize the tokenizer.

        Args:
            vocab_path: Optional path to a JSON vocabulary written by
                ``save_vocab``. If it is given and exists, it is loaded.
                Otherwise the default vocabulary is built from the Sinhala
                Unicode block. A path that is given but missing triggers a
                warning, because silently falling back to default IDs can
                mismatch a model trained on a custom vocabulary.
        """
        if vocab_path and Path(vocab_path).exists():
            self.load_vocab(vocab_path)
            return
        if vocab_path:
            import warnings
            warnings.warn(
                f"Vocabulary file {vocab_path!r} not found; "
                "falling back to the default Sinhala vocabulary.",
                stacklevel=2,
            )
        self._build_default_vocab()

    def _build_default_vocab(self):
        """Build the vocabulary from the Sinhala Unicode block plus common tokens.

        IDs are assigned consecutively in the order of ``groups`` below.
        Changing that order changes every ID after the change point.
        """
        self.token2id: Dict[str, int] = {}
        self.id2token: Dict[int, str] = {}
        groups = (
            [self.PAD, self.BOS, self.EOS, self.UNK, self.BLANK],  # IDs 0-4
            [self.SPACE],
            self._VOCAB_PUNCTUATION,
            SINHALA_VOWELS,
            SINHALA_CONSONANTS,
            SINHALA_VOWEL_SIGNS,
            [VIRAMA, ZWJ, ANUSVARA, VISARGA],
            '0123456789',
            SINHALA_DIGITS,
        )
        for group in groups:
            for token in group:
                if token in self.token2id:
                    continue
                idx = len(self.token2id)
                self.token2id[token] = idx
                self.id2token[idx] = token
        self.vocab_size = len(self.token2id)

    def normalize(self, text: str) -> str:
        """
        Normalize Sinhala text for TTS.

        - NFC normalization
        - Remove ZWNJ (keep ZWJ)
        - Map curly quotes, en/em dashes and '…' to ASCII; map ';' and ':'
          to ','; drop parentheses
        - Collapse whitespace runs to a single space and trim both ends
        """
        text = unicodedata.normalize('NFC', text)
        text = text.translate(self._NORMALIZE_TABLE)
        # split() with no argument also discards leading/trailing whitespace.
        return ' '.join(text.split())

    def tokenize(self, text: str, normalize: bool = True) -> List[str]:
        """
        Tokenize text into grapheme clusters.

        Returns a list of tokens (aksharas, numbers, punctuation, spaces).
        """
        if normalize:
            text = self.normalize(text)
        return [m.group(0) for m in AKSHARA_PATTERN.finditer(text) if m.group(0)]

    def encode(self, text: str, add_bos: bool = True, add_eos: bool = True) -> List[int]:
        """
        Encode text to integer IDs.

        For multi-codepoint aksharas (conjuncts), each codepoint gets its own
        ID. This is by design: FastPitch's character embedding handles the
        sequence, and the attention/aligner learns the mapping to mel frames.
        Codepoints missing from the vocabulary map to the <UNK> ID.
        """
        if isinstance(text, str):
            text = self.normalize(text)
        ids = []
        if add_bos:
            ids.append(self.token2id[self.BOS])
        for char in text:
            token_id = self.token2id.get(char)
            # Look up UNK lazily so vocabularies without it still encode known text.
            ids.append(token_id if token_id is not None else self.token2id[self.UNK])
        if add_eos:
            ids.append(self.token2id[self.EOS])
        return ids

    def decode(self, ids: List[int], strip_special: bool = True) -> str:
        """Decode integer IDs back to text.

        Unknown IDs decode as <UNK>. With ``strip_special`` (the default),
        PAD/BOS/EOS/BLANK/UNK are dropped.
        """
        specials = {self.PAD, self.BOS, self.EOS, self.BLANK, self.UNK}
        chars = []
        for idx in ids:
            token = self.id2token.get(idx, self.UNK)
            if strip_special and token in specials:
                continue
            chars.append(token)
        return ''.join(chars)

    def get_characters_string(self) -> str:
        """
        Get all characters as a single string for Coqui-TTS config.

        Returns the single-character vocabulary tokens in ID order, excluding
        special tokens, the space and vocabulary punctuation.

        Usage in Coqui config:
            config.characters.characters = tokenizer.get_characters_string()
        """
        excluded = {
            self.PAD, self.BOS, self.EOS, self.UNK, self.BLANK, self.SPACE,
            *self._VOCAB_PUNCTUATION,
        }
        ordered = sorted(self.token2id, key=self.token2id.__getitem__)
        return ''.join(t for t in ordered if len(t) == 1 and t not in excluded)

    def get_punctuations_string(self) -> str:
        """Get punctuation characters (including space) for Coqui-TTS config."""
        return '!?.,;:\'"- '

    def save_vocab(self, path: str):
        """Save the vocabulary to JSON (UTF-8, non-ASCII kept readable)."""
        data = {
            "token2id": self.token2id,
            "vocab_size": self.vocab_size,
            "description": "Sinhala TTS grapheme tokenizer vocabulary",
        }
        with open(path, "w", encoding="utf-8") as f:
            json.dump(data, f, indent=2, ensure_ascii=False)

    def load_vocab(self, path: str):
        """Load a vocabulary from JSON written by ``save_vocab``."""
        with open(path, "r", encoding="utf-8") as f:
            data = json.load(f)
        self.token2id = data["token2id"]
        self.id2token = {int(v): k for k, v in self.token2id.items()}
        self.vocab_size = len(self.token2id)

    def expand_vocab(self, new_tokens: List[str]):
        """Add new tokens to the vocabulary (e.g., from training data).

        New IDs start after the current maximum ID rather than at
        ``len(token2id)``. For a loaded vocabulary with gaps in its IDs,
        the latter would reuse an existing ID and overwrite ``id2token``.
        """
        next_id = max(self.id2token, default=-1) + 1
        for token in new_tokens:
            if token in self.token2id:
                continue
            self.token2id[token] = next_id
            self.id2token[next_id] = token
            next_id += 1
        self.vocab_size = len(self.token2id)

    def build_vocab_from_corpus(self, texts: List[str]):
        """
        Scan a corpus and add any characters not in the current vocab.

        Call this after loading your training data to ensure all characters
        are covered. Texts are normalized first, exactly as ``encode`` does.
        """
        unseen = set()
        for text in texts:
            unseen.update(ch for ch in self.normalize(text) if ch not in self.token2id)
        if unseen:
            print(f"Adding {len(unseen)} new characters to vocab: {unseen}")
            self.expand_vocab(sorted(unseen))

    def __len__(self):
        return self.vocab_size

    def __repr__(self):
        return (f"SinhalaTokenizer(vocab_size={self.vocab_size}, "
                f"vowels={len(SINHALA_VOWELS)}, "
                f"consonants={len(SINHALA_CONSONANTS)}, "
                f"vowel_signs={len(SINHALA_VOWEL_SIGNS)})")
# ============================================================
# Utility functions
# ============================================================
def syllabify(text: str) -> List[str]:
    """
    Split Sinhala text into aksharas without needing a tokenizer instance.

    The text is NFC-normalized (no other normalization), and whitespace-only
    matches are discarded, so the result has no space tokens.
    """
    composed = unicodedata.normalize('NFC', text)
    pieces = (found.group(0) for found in AKSHARA_PATTERN.finditer(composed))
    return [piece for piece in pieces if piece and piece.strip()]
def count_aksharas(text: str) -> int:
    """Return how many non-whitespace clusters ``syllabify`` finds in *text*."""
    clusters = syllabify(text)
    return len(clusters)
def is_sinhala(text: str) -> bool:
    """
    Return True when more than half of the non-whitespace characters of
    *text* lie in the Sinhala block (U+0D80–U+0DFF).

    Empty or whitespace-only text returns False.
    """
    visible = [ch for ch in text if not ch.isspace()]
    if not visible:
        return False
    in_block = sum(1 for ch in visible if '\u0D80' <= ch <= '\u0DFF')
    return in_block / len(visible) > 0.5
def get_coqui_characters_config() -> dict:
    """
    Build the character-config keyword arguments for Coqui-TTS FastPitch.

    A default-vocabulary SinhalaTokenizer supplies the character and
    punctuation strings; the special tokens come from the class constants.

    Usage:
        from TTS.tts.configs.shared_configs import CharactersConfig
        char_config = CharactersConfig(**get_coqui_characters_config())
    """
    cls = SinhalaTokenizer
    default_tok = cls()
    return dict(
        pad=cls.PAD,
        eos=cls.EOS,
        bos=cls.BOS,
        blank=cls.BLANK,
        characters=default_tok.get_characters_string(),
        punctuations=default_tok.get_punctuations_string(),
        phonemes=None,
        is_unique=True,
    )
# ============================================================
# CLI: test the tokenizer
# ============================================================
if __name__ == "__main__":
    # Demo: tokenize/encode/decode sample sentences, print the Coqui
    # character config, then write the default vocab to the current directory.
    demo = SinhalaTokenizer()
    print(f"Tokenizer: {demo}")
    print(f"Vocab size: {demo.vocab_size}")
    print()

    # Sample sentences exercising conjuncts, punctuation and digits
    samples = (
        "ශ්‍රී ලංකාව",  # Sri Lanka (with conjunct ශ්‍ර)
        "මෙය උදාහරණ වාක්‍යයකි.",  # "This is an example sentence."
        "සිංහල භාෂාව ඉතා සුන්දරයි!",  # "Sinhala language is very beautiful!"
        "ක්‍රිකට් ක්‍රීඩාව",  # "Cricket sport" (conjuncts)
        "බුද්ධ ශාසනය",  # "Buddhist dispensation"
        "123 දවස්",  # "123 days" (mixed digits)
    )

    for sample in samples:
        print(f"Input: {sample}")
        aksharas = demo.tokenize(sample)
        print(f"Aksharas: {aksharas}")
        encoded = demo.encode(sample)
        print(f"IDs: {encoded}")
        round_trip = demo.decode(encoded)
        print(f"Decoded: {round_trip}")
        # decode(encode(x)) equals normalize(x) when every codepoint is in the vocab
        verdict = "✓" if round_trip == demo.normalize(sample) else "✗"
        print(f"Match: {verdict}")
        print()

    banner = "=" * 60
    print(banner)
    print("FOR COQUI-TTS CONFIG:")
    print(banner)
    coqui = get_coqui_characters_config()
    print(f"characters: {coqui['characters']!r}")
    print(f"punctuations: {coqui['punctuations']!r}")
    print(f"Total unique characters: {len(coqui['characters'])}")

    out_path = "sinhala_vocab.json"
    demo.save_vocab(out_path)
    print(f"\nVocab saved to: {out_path}")