# tokenization_tinytransformer.py (most complete fixed version)
import json
import os
from typing import Dict, List, Optional, Tuple

from transformers import PreTrainedTokenizer

class TinyTokenizer(PreTrainedTokenizer):
    vocab_files_names = {"vocab_file": "vocab.json"}
    pretrained_vocab_files_map = {}
    max_model_input_sizes = {"tinytransformer": 512}
    model_input_names = ["input_ids", "attention_mask"]

    def __init__(self, vocab_file: Optional[str] = None, **kwargs):
        # Special tokens
        self.special_tokens = {
            "[PAD]": 0,
            "[UNK]": 1,
            "[CLS]": 2,
            "[SEP]": 3,
        }
        # Build the vocab. It is reconstructed deterministically, so the
        # vocab_file argument is accepted for API compatibility but not read.
        self.vocab: Dict[str, int] = self.special_tokens.copy()
        offset = len(self.vocab)
        # Printable ASCII characters
        for i in range(32, 127):
            self.vocab[chr(i)] = offset + i - 32
        # Chinese support (CJK Unified Ideographs block; extend as needed)
        for i in range(0x4E00, 0x9FFF + 1):
            char = chr(i)
            if char not in self.vocab:
                self.vocab[char] = len(self.vocab)
        self.id_to_token = {v: k for k, v in self.vocab.items()}
        # No manual *_token_id assignments here: the base class derives the
        # pad/unk/cls/sep ids by looking the token strings up in the vocab,
        # and assigning them before super().__init__() is fragile.
        super().__init__(
            pad_token="[PAD]",
            unk_token="[UNK]",
            cls_token="[CLS]",
            sep_token="[SEP]",
            **kwargs,
        )

    def get_vocab(self) -> Dict[str, int]:
        # Include tokens added after construction, per the HF contract.
        return dict(self.vocab, **self.added_tokens_encoder)

    @property
    def vocab_size(self) -> int:
        return len(self.vocab)

    def _tokenize(self, text: str) -> List[str]:
        return list(text)  # character-level tokenization

    # Override the private hooks rather than the public convert_* methods, so
    # the base class keeps handling added tokens and skip_special_tokens
    # (overriding the public methods breaks decode(), which passes kwargs).
    def _convert_token_to_id(self, token: str) -> int:
        return self.vocab.get(token, self.special_tokens["[UNK]"])

    def _convert_id_to_token(self, index: int) -> str:
        return self.id_to_token.get(index, "[UNK]")

    def convert_tokens_to_string(self, tokens: List[str]) -> str:
        # Character-level tokens concatenate directly, with no joining spaces.
        return "".join(tokens)

    def build_inputs_with_special_tokens(
        self, token_ids_0: List[int], token_ids_1: Optional[List[int]] = None
    ) -> List[int]:
        # Single sequence: [CLS] A [SEP]; pair: [CLS] A [SEP] B [SEP]
        if token_ids_1 is None:
            return [self.cls_token_id] + token_ids_0 + [self.sep_token_id]
        return (
            [self.cls_token_id] + token_ids_0 + [self.sep_token_id]
            + token_ids_1 + [self.sep_token_id]
        )

    def create_token_type_ids_from_sequences(
        self, token_ids_0: List[int], token_ids_1: Optional[List[int]] = None
    ) -> List[int]:
        # Segment 0 covers [CLS] A [SEP]; segment 1 covers B [SEP].
        if token_ids_1 is None:
            return [0] * (len(token_ids_0) + 2)
        return [0] * (len(token_ids_0) + 2) + [1] * (len(token_ids_1) + 1)

    def save_vocabulary(self, save_directory: str, filename_prefix: Optional[str] = None) -> Tuple[str]:
        # Follow the Hugging Face "<prefix>-vocab.json" naming convention.
        prefix = filename_prefix + "-" if filename_prefix else ""
        vocab_path = os.path.join(save_directory, prefix + "vocab.json")
        with open(vocab_path, "w", encoding="utf-8") as f:
            json.dump(self.vocab, f, ensure_ascii=False, indent=2)
        return (vocab_path,)
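
# Optional: register this class with AutoTokenizer so that
# AutoTokenizer.from_pretrained can resolve it automatically. A sketch of the
# supported path; it assumes the repo also defines a TinyTransformerConfig
# (that config module and class name are an assumption, not confirmed here):
#
# from transformers import AutoTokenizer
# from configuration_tinytransformer import TinyTransformerConfig
#
# AutoTokenizer.register(TinyTransformerConfig, slow_tokenizer_class=TinyTokenizer)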

# Legacy alternative: mutate the auto-tokenizer mapping directly. Note that
# TOKENIZER_MAPPING_NAMES is not exported from the top-level transformers
# package (it lives in transformers.models.auto.tokenization_auto) and maps a
# model type to a (slow_tokenizer, fast_tokenizer) name tuple, not a bare
# class name:
# # At the very bottom of the file, after the class definition
# # Register (runs only once; keeping it here is safest)
# from transformers.models.auto.tokenization_auto import TOKENIZER_MAPPING_NAMES
# TOKENIZER_MAPPING_NAMES["tinytransformer"] = ("TinyTokenizer", None)
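
# Minimal smoke test; a sketch assuming `transformers` is installed. It
# exercises the character-level encode/decode round trip on mixed
# Chinese/ASCII input.
if __name__ == "__main__":
    tokenizer = TinyTokenizer()
    text = "你好, world!"
    encoded = tokenizer(text)  # __call__ wraps the characters in [CLS] ... [SEP]
    print(encoded["input_ids"])
    print(encoded["attention_mask"])
    decoded = tokenizer.decode(encoded["input_ids"], skip_special_tokens=True)
    print("round trip OK:", decoded == text)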