# Author: KHUN Kimang
# Date: March 2026
# KrorngAI
# Inspired by https://github.com/openai/whisper/blob/main/whisper/tokenizer.py
from typing import Optional, Tuple, List
from dataclasses import dataclass, field
from functools import cached_property
from enum import Enum

from transformers import LlamaTokenizer, PreTrainedTokenizer
import json

# ISO 639-1 code -> full language name for the two supported languages.
LANGUAGES = {
    "km": "khmer",
    "en": "english",
}

# Reverse mapping: full language name -> ISO code (e.g. "khmer" -> "km"),
# so callers may pass either form to the tokenizer.
TO_LANGUAGE_CODE = {
    **{lang: code for code, lang in LANGUAGES.items()},
}


class ASRSpecialTokens(str, Enum):
    """Special control tokens appended to the tokenizer vocabulary for ASR."""

    km_token = "<|km|>"  # language token must be added to lm_head of Decoder Model
    en_token = "<|en|>"  # language token must be added to lm_head of Decoder Model
    transcribe = "<|transcribe|>"
    translate = "<|translate|>"
    no_speech = "<|nospeech|>"

    @classmethod
    def list(cls) -> List[str]:
        """Return all special-token strings, in declaration order."""
        return [c.value for c in cls]


class TrorYongASRTokenizer(LlamaTokenizer):
    """
    Tokenizer for the ASR task.
    It supports only two languages: Khmer and English.
    It does not support timestamps.
    """

    def __init__(
        self,
        language: Optional[str] = None,
        task: Optional[str] = None,
        *args,
        **kwargs,
    ):
        """
        Args:
            language: language name or ISO code ("khmer"/"km", "english"/"en"),
                or None for a language-agnostic start-of-transcript sequence.
            task: "transcribe" or "translate", or None to omit the task token.
            *args, **kwargs: forwarded to ``LlamaTokenizer.__init__``.

        Raises:
            ValueError: if ``language`` is not one of the supported languages.
        """
        self.language = language
        self.task = task
        super().__init__(*args, **kwargs)
        self.add_special_tokens({
            "additional_special_tokens": ASRSpecialTokens.list()
        })
        # Map each special-token string to its single token id.
        self.special_tokens = dict()
        for special in self.all_special_tokens:
            special_id = self.encode(special, add_special_tokens=False)[0]
            self.special_tokens[special] = special_id

        # NOTE(review): the source read ``self.special_tokens[""]`` here, which
        # can never be a key of ``special_tokens``; reconstructed as the Llama
        # BOS token "<s>" (likely stripped as markup upstream) — confirm.
        sot: int = self.special_tokens["<s>"]
        translate: int = self.special_tokens["<|translate|>"]
        transcribe: int = self.special_tokens["<|transcribe|>"]

        # Build the start-of-transcript prefix: <s> [<|lang|>] [<|task|>]
        sot_sequence = [sot]
        if self.language is not None:
            language = self.language.lower()
            if language not in LANGUAGES:
                if language in TO_LANGUAGE_CODE:
                    language = TO_LANGUAGE_CODE[language]
                else:
                    raise ValueError(f"Unsupported language: {language}")
            self.language = language
            lang_id = self.encode(f"<|{language}|>", add_special_tokens=False)[0]
            sot_sequence.append(lang_id)
        if self.task is not None:
            task_token: int = transcribe if self.task == "transcribe" else translate
            sot_sequence.append(task_token)
        self.sot_sequence = tuple(sot_sequence)

    def encode(self, text, **kwargs) -> List[int]:
        """Encode ``text``, dropping a leading SentencePiece whitespace token.

        29871 is the leading-whitespace token for TinyKhmerTokenizer; stripping
        it makes special tokens round-trip to a single id. Guard against empty
        encodings (e.g. ``text == ""``) before inspecting the first id.
        """
        encoding = super().encode(text, **kwargs)
        if encoding and encoding[0] == 29871:
            return encoding[1:]
        return encoding

    def __call__(self, text: Optional[str] = None) -> List[int]:
        """Return the start-of-transcript prefix followed by the encoded text."""
        encoding = self.encode(text, add_special_tokens=False)
        return [*self.sot_sequence] + encoding

    @cached_property
    def eot(self) -> int:
        # NOTE(review): source read ``special_tokens[""]``; reconstructed as
        # the Llama EOS token "</s>" — confirm.
        return self.special_tokens["</s>"]

    @cached_property
    def transcribe(self) -> int:
        return self.special_tokens["<|transcribe|>"]

    @cached_property
    def translate(self) -> int:
        return self.special_tokens["<|translate|>"]

    @cached_property
    def sot(self) -> int:
        # NOTE(review): source read ``special_tokens[""]``; reconstructed as
        # the Llama BOS token "<s>" — confirm.
        return self.special_tokens["<s>"]

    @cached_property
    def no_speech(self) -> int:
        return self.special_tokens["<|nospeech|>"]

    @cached_property
    def language_token(self) -> int:
        """Returns the token id corresponding to the value of the `language` field"""
        if self.language is None:
            raise ValueError("This tokenizer does not have language token configured")
        return self.to_language_token(self.language)

    def to_language_token(self, language) -> int:
        """Return the token id for ``<|language|>``.

        Raises:
            KeyError: if the language has no special token in this tokenizer.
        """
        token = self.special_tokens.get(f"<|{language}|>")
        # Compare against None explicitly: a token id of 0 is falsy but valid.
        if token is not None:
            return token
        raise KeyError(f"Language {language} not found in tokenizer.")

    @cached_property
    def all_language_tokens(self) -> Tuple[int, ...]:
        """Token ids of all language special tokens (e.g. <|km|>, <|en|>)."""
        result = []
        for token, token_id in self.special_tokens.items():
            if token.strip("<|>") in LANGUAGES:
                result.append(token_id)
        return tuple(result)

    @cached_property
    def all_language_codes(self) -> Tuple[str, ...]:
        """ISO codes corresponding to ``all_language_tokens``, in the same order."""
        return tuple(self.decode([_l]).strip("<|>") for _l in self.all_language_tokens)

    @cached_property
    def non_speech_tokens(self) -> Tuple[int, ...]:
        """
        Returns the list of tokens to suppress in order to avoid any speaker tags
        or non-speech annotations, to prevent sampling texts that are not actually
        spoken in the audio, e.g.

        - ♪♪♪
        - ( SPEAKING FOREIGN LANGUAGE )
        - [DAVID] Hey there,

        keeping basic punctuations like commas, periods, question marks,
        exclamation points, etc.
        """
        symbols = list('"#()*+/:;<=>@[\\]^_`{|}~「」『』')
        symbols += (
            "<< >> <<< >>> -- --- -( -[ (' (\" (( )) ((( ))) [[ ]] {{ }} ♪♪ ♪♪♪".split()
        )

        # symbols that may be a single token or multiple tokens depending on the tokenizer.
        # In case they're multiple tokens, suppress the first token, which is safe because:
        # These are between U+2640 and U+267F miscellaneous symbols that are okay to suppress
        # in generations, and in the 3-byte UTF-8 representation they share the first two bytes.
        miscellaneous = set("♩♪♫♬♭♮♯")
        assert all(0x2640 <= ord(c) <= 0x267F for c in miscellaneous)

        # allow hyphens "-" and single quotes "'" between words, but not at the beginning of a word
        result = {
            self.encode(" -", add_special_tokens=False)[0],
            self.encode(" '", add_special_tokens=False)[0],
        }
        for symbol in symbols + list(miscellaneous):
            for tokens in [
                self.encode(symbol, add_special_tokens=False),
                self.encode(" " + symbol, add_special_tokens=False),
            ]:
                if len(tokens) == 1 or symbol in miscellaneous:
                    result.add(tokens[0])

        return tuple(sorted(result))