FireRedASR / fireredasr /data /token_dict.py
luoqiu's picture
添加相关代码
4ef9456
import logging
class TokenDict:
def __init__(self, dict_path, unk=""):
assert dict_path != ""
self.id2word, self.word2id = self.read_dict(dict_path)
self.unk = unk
assert unk == "" or unk in self.word2id
self.unkid = self.word2id[unk] if unk else -1
def get(self, key, default):
if type(default) == str:
default = self.word2id[default]
return self.word2id.get(key, default)
def __getitem__(self, key):
if type(key) == str:
if self.unk:
return self.word2id.get(key, self.word2id[self.unk])
else:
return self.word2id[key]
elif type(key) == int:
return self.id2word[key]
else:
raise TypeError("Key should be str or int")
def __len__(self):
return len(self.id2word)
def __contains__(self, query):
if type(query) == str:
return query in self.word2id
elif type(query) == int:
return query in self.id2word
else:
raise TypeError("query should be str or int")
def read_dict(self, dict_path):
id2word, word2id = [], {}
with open(dict_path, encoding='utf8') as f:
for i, line in enumerate(f):
tokens = line.strip().split()
if len(tokens) >= 2:
word, index = tokens[0], int(tokens[1])
elif len(tokens) == 1:
word, index = tokens[0], i
else: # empty line or space
logging.info(f"Find empty line or space '{line.strip()}' in {dict_path}:L{i}, set to ' '")
word, index = " ", i
assert len(id2word) == index
assert len(word2id) == index
if word == "<space>":
logging.info(f"NOTE: Find <space> in {dict_path}:L{i} and convert it to ' '")
word = " "
word2id[word] = index
id2word.append(word)
assert len(id2word) == len(word2id)
return id2word, word2id