Automatic Speech Recognition
Transformers
Safetensors
Khmer
English
troryongasr
custom_code
File size: 6,151 Bytes
0133579
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
# Author: KHUN Kimang
# Date: March 2026
# KrorngAI
# Inspired by https://github.com/openai/whisper/blob/main/whisper/tokenizer.py

from typing import Optional, Tuple, List
from dataclasses import dataclass, field
from functools import cached_property
from enum import Enum
from transformers import LlamaTokenizer, PreTrainedTokenizer
import json


# Supported languages, keyed by their short language code.
LANGUAGES = {"km": "khmer", "en": "english"}

# Reverse lookup: full language name -> short language code.
TO_LANGUAGE_CODE = {name: code for code, name in LANGUAGES.items()}

class ASRSpecialTokens(str, Enum):
    """String-valued enumeration of the extra special tokens used for ASR."""

    # Language tokens: these must be added to lm_head of Decoder Model.
    km_token = "<|km|>"
    en_token = "<|en|>"

    # Task tokens.
    transcribe = "<|transcribe|>"
    translate = "<|translate|>"

    # Marker for the absence of speech.
    no_speech = "<|nospeech|>"

    @classmethod
    def list(cls):
        """Return the string value of every member, in declaration order."""
        values = []
        for member in cls:
            values.append(member.value)
        return values


class TrorYongASRTokenizer(LlamaTokenizer):
    """
    Tokenizer for the ASR task.
    It supports only two languages: Khmer and English.
    It does not support timestamps.

    On construction the tokenizer registers the ``ASRSpecialTokens`` values as
    additional special tokens and builds two attributes:

    - ``special_tokens``: mapping from every special-token string to its id.
    - ``sot_sequence``: tuple of ids ``<s>`` [+ language token] [+ task token]
      that ``__call__`` prepends to every encoded text.
    """

    # Token id dropped by ``encode`` when it is the first id of an encoding
    # (29871 is whitespace for TinyKhmerTokenizer).
    _LEADING_WHITESPACE_ID = 29871

    def __init__(
        self,
        language: Optional[str] = None,
        task: Optional[str] = None,
        *args,
        **kwargs
    ):
        """
        Args:
            language: Language code ("km", "en") or language name ("khmer",
                "english"), case-insensitive. ``None`` leaves the language
                token out of ``sot_sequence``.
            task: ``"transcribe"`` selects the transcribe token; any other
                non-``None`` value selects the translate token. ``None``
                leaves the task token out of ``sot_sequence``.
            *args, **kwargs: Forwarded unchanged to ``LlamaTokenizer``.

        Raises:
            ValueError: If ``language`` is neither a known code nor a known name.
        """
        # Set before the parent constructor runs, as in the original code.
        self.language = language
        self.task = task

        super().__init__(
            *args,
            **kwargs
        )
        self.add_special_tokens({
            "additional_special_tokens": ASRSpecialTokens.list()
        })

        # Map each special-token string to the first id of its encoding.
        self.special_tokens = {
            special: self.encode(special, add_special_tokens=False)[0]
            for special in self.all_special_tokens
        }

        sot: int = self.special_tokens["<s>"]
        translate: int = self.special_tokens["<|translate|>"]
        transcribe: int = self.special_tokens["<|transcribe|>"]

        sot_sequence = [sot]
        if self.language is not None:
            language = self.language.lower()
            if language not in LANGUAGES:
                # Accept full language names by mapping them back to codes.
                if language in TO_LANGUAGE_CODE:
                    language = TO_LANGUAGE_CODE[language]
                else:
                    raise ValueError(f"Unsupported language: {language}")

            # Store the normalized language code.
            self.language = language
            lang_id = self.encode(f"<|{language}|>", add_special_tokens=False)[0]
            sot_sequence.append(lang_id)
        if self.task is not None:
            # Anything other than "transcribe" falls back to the translate token.
            task_token: int = transcribe if self.task == "transcribe" else translate
            sot_sequence.append(task_token)

        self.sot_sequence = tuple(sot_sequence)

    def encode(self, text, **kwargs) -> List[int]:
        """
        Encode ``text`` with the parent tokenizer and drop a leading
        whitespace token if present.

        An empty encoding is returned as-is instead of raising ``IndexError``.
        """
        encoding = super().encode(text, **kwargs)
        if encoding and encoding[0] == self._LEADING_WHITESPACE_ID:
            return encoding[1:]
        return encoding

    def __call__(self, text: Optional[str] = None) -> List[int]:
        """
        Return ``sot_sequence`` followed by the ids of ``text``.

        ``text=None`` (the declared default) is treated as "no text" and yields
        only the ``sot_sequence`` ids, rather than forwarding ``None`` to
        ``encode``.
        """
        if text is None:
            return list(self.sot_sequence)
        encoding = self.encode(text, add_special_tokens=False)
        return [*self.sot_sequence, *encoding]

    @cached_property
    def eot(self) -> int:
        """Id of the ``</s>`` token."""
        return self.special_tokens["</s>"]

    @cached_property
    def transcribe(self) -> int:
        """Id of the ``<|transcribe|>`` token."""
        return self.special_tokens["<|transcribe|>"]

    @cached_property
    def translate(self) -> int:
        """Id of the ``<|translate|>`` token."""
        return self.special_tokens["<|translate|>"]

    @cached_property
    def sot(self) -> int:
        """Id of the ``<s>`` token."""
        return self.special_tokens["<s>"]

    @cached_property
    def no_speech(self) -> int:
        """Id of the ``<|nospeech|>`` token."""
        return self.special_tokens["<|nospeech|>"]

    @cached_property
    def language_token(self) -> int:
        """Returns the token id corresponding to the value of the `language` field"""
        if self.language is None:
            raise ValueError("This tokenizer does not have language token configured")

        return self.to_language_token(self.language)

    def to_language_token(self, language) -> int:
        """
        Return the id of the ``<|language|>`` special token.

        Raises:
            KeyError: If no such special token is registered.
        """
        # Compare against None explicitly so that a valid id of 0 is not
        # mistaken for "missing".
        token = self.special_tokens.get(f"<|{language}|>")
        if token is None:
            raise KeyError(f"Language {language} not found in tokenizer.")
        return token

    @cached_property
    def all_language_tokens(self) -> Tuple[int, ...]:
        """Ids of every special token whose name is a code in ``LANGUAGES``."""
        return tuple(
            token_id
            for token, token_id in self.special_tokens.items()
            if token.strip("<|>") in LANGUAGES
        )

    @cached_property
    def all_language_codes(self) -> Tuple[str, ...]:
        """Language codes obtained by decoding each id in ``all_language_tokens``."""
        return tuple(self.decode([_l]).strip("<|>") for _l in self.all_language_tokens)

    @cached_property
    def non_speech_tokens(self) -> Tuple[int, ...]:
        """
        Returns the sorted tuple of tokens to suppress in order to avoid any speaker tags or
        non-speech annotations, to prevent sampling texts that are not actually spoken in the
        audio, e.g.

        - ♪♪♪
        - ( SPEAKING FOREIGN LANGUAGE )
        - [DAVID] Hey there,

        keeping basic punctuations like commas, periods, question marks, exclamation points, etc.
        """
        symbols = list('"#()*+/:;<=>@[\\]^_`{|}~「」『』')
        symbols += (
            "<< >> <<< >>> -- --- -( -[ (' (\" (( )) ((( ))) [[ ]] {{ }} ♪♪ ♪♪♪".split()
        )

        # symbols that may be a single token or multiple tokens depending on the tokenizer.
        # In case they're multiple tokens, suppress the first token, which is safe because:
        # These are between U+2640 and U+267F miscellaneous symbols that are okay to suppress
        # in generations, and in the 3-byte UTF-8 representation they share the first two bytes.
        miscellaneous = set("♩♪♫♬♭♮♯")
        assert all(0x2640 <= ord(c) <= 0x267F for c in miscellaneous)

        # allow hyphens "-" and single quotes "'" between words, but not at the beginning of a word
        result = {self.encode(" -", add_special_tokens=False)[0], self.encode(" '", add_special_tokens=False)[0]}
        for symbol in symbols + list(miscellaneous):
            for tokens in [
                self.encode(symbol, add_special_tokens=False),
                self.encode(" " + symbol, add_special_tokens=False),
            ]:
                # Skip empty encodings so tokens[0] cannot raise IndexError.
                if tokens and (len(tokens) == 1 or symbol in miscellaneous):
                    result.add(tokens[0])

        return tuple(sorted(result))