| """分词器:CharTokenizer,从切割文本得到 token IDs"""
|
|
|
|
|
| import json
|
| from pathlib import Path
|
| from collections import Counter
|
|
|
|
|
class CharTokenizer:
    """Character-level tokenizer for English text.

    Maps individual characters to integer IDs and back. ID 0 is reserved
    for the unknown-character token ``'<unk>'``; all other IDs are
    assigned in descending order of character frequency.
    """

    def __init__(self, vocab_path=None):
        """
        Initialize the tokenizer.

        Args:
            vocab_path: optional path to a vocabulary file (JSON format).
                If given and the file exists, the vocabulary is loaded
                immediately; otherwise the tokenizer starts empty and
                ``build_vocab`` must be called before use.
        """
        self.vocab_path = vocab_path
        self.char_to_id = {}   # char -> int ID
        self.id_to_char = {}   # int ID -> char (inverse mapping)
        self.vocab_size = 0

        if vocab_path and Path(vocab_path).exists():
            self.load_vocab(vocab_path)

    def build_vocab(self, texts):
        """
        Build the vocabulary from text.

        Characters are assigned IDs starting at 1 in descending order of
        frequency; ID 0 is reserved for ``'<unk>'``. Rebuilding replaces
        any previously built or loaded vocabulary.

        Args:
            texts: a list of strings or a single string.
        """
        if isinstance(texts, str):
            texts = [texts]

        char_counts = Counter(''.join(texts))

        # ID 0 is reserved for the unknown-character token.
        self.char_to_id = {'<unk>': 0}

        # Most frequent characters get the smallest IDs. sorted() is
        # stable, so equal-count characters keep their first-seen order.
        ranked = sorted(char_counts.items(), key=lambda item: item[1],
                        reverse=True)
        for char, _count in ranked:
            if char not in self.char_to_id:
                self.char_to_id[char] = len(self.char_to_id)

        self.id_to_char = {char_id: char
                           for char, char_id in self.char_to_id.items()}
        self.vocab_size = len(self.char_to_id)

    def encode(self, text):
        """
        Encode text into a list of token IDs.

        Args:
            text: input string; split into individual characters.

        Returns:
            List of int token IDs. Characters missing from the
            vocabulary map to the '<unk>' ID (0 when '<unk>' itself is
            absent from the vocabulary).
        """
        # Hoist the loop-invariant fallback lookup out of the per-char loop.
        unk_id = self.char_to_id.get('<unk>', 0)
        return [self.char_to_id.get(char, unk_id) for char in text]

    def decode(self, token_ids):
        """
        Decode a list of token IDs back into text.

        Args:
            token_ids: list of int IDs, or any object exposing
                ``.tolist()`` (e.g. a torch/numpy tensor).

        Returns:
            Decoded string. Unknown IDs are silently dropped, so
            ``decode(encode(text))`` loses out-of-vocabulary characters.
        """
        # Accept tensors/arrays without importing any tensor library.
        if hasattr(token_ids, 'tolist'):
            token_ids = token_ids.tolist()

        chars = (self.id_to_char.get(token_id, '<unk>')
                 for token_id in token_ids)
        # Drop the '<unk>' placeholder rather than rendering it literally.
        return ''.join(char for char in chars if char != '<unk>')

    def save_vocab(self, vocab_path):
        """Save the vocabulary to a JSON file, creating parent dirs as needed."""
        vocab_data = {
            'char_to_id': self.char_to_id,
            # JSON object keys must be strings; int IDs are stringified
            # here and converted back to int by load_vocab().
            'id_to_char': {str(k): v for k, v in self.id_to_char.items()},
            'vocab_size': self.vocab_size
        }

        Path(vocab_path).parent.mkdir(parents=True, exist_ok=True)

        with open(vocab_path, 'w', encoding='utf-8') as f:
            json.dump(vocab_data, f, ensure_ascii=False, indent=2)

    def load_vocab(self, vocab_path):
        """Load a vocabulary previously written by save_vocab()."""
        with open(vocab_path, 'r', encoding='utf-8') as f:
            vocab_data = json.load(f)

        self.char_to_id = vocab_data['char_to_id']
        # Restore the int keys that JSON serialization forced to strings.
        self.id_to_char = {int(k): v
                           for k, v in vocab_data['id_to_char'].items()}
        self.vocab_size = vocab_data['vocab_size']