ver1: Implement a deep-learning training framework supporting the dual tasks of Wiki GPT and poem generation
a5fd608
"""Wiki 数据集主模块
实现 WikiDataset 类,继承自 DataBundle。
"""
import pathlib
from dataclasses import dataclass, field
from typing import Optional

import tensorflow as tf

from data.base import DataBundle, TokenizerBundle
from data.wiki.loader import doc_load
from data.wiki.transformer import transform
from data.wiki.tokenizer import sentence_piece, character_vectorization


@dataclass
class WikiDataset(DataBundle):
"""Wiki 数据集
将文档加载、分词、统计等功能绑定在一起的数据集类。
Usage:
dataset = WikiDataset(
data_dir="~/data/wiki/mini_c4",
tokenizer_type="sentence_piece" # 或 "character"
)
# 获取文档数据集
doc_ds = dataset.doc_ds()
# 获取 token 数据集
tokens_ds = dataset.tokens_ds(seq_length=256, batch_size=32)
# 打印统计信息
dataset.stat(seq_length=256)
"""
    glob_pattern: str = "*"
    tokenizer_type: str = "sentence_piece"
    _data_path: pathlib.Path = field(init=False, repr=False)
    _tokenizer_bundle: Optional[TokenizerBundle] = field(
        init=False, repr=False, default=None
    )

    def __post_init__(self):
        # data_dir comes from the DataBundle base class.
        self._data_path = pathlib.Path(self.data_dir).expanduser()

    def _load_tokenizer(self):
        """Lazily build the tokenizer on first use."""
        if self._tokenizer_bundle is None:
            if self.tokenizer_type == "sentence_piece":
                tokenizer, end_of_text, decode = sentence_piece()
            elif self.tokenizer_type == "character":
                tokenizer, end_of_text, decode = character_vectorization()
            else:
                raise ValueError(f"Unknown tokenizer type: {self.tokenizer_type}")
            vocab_size = tokenizer.vocabulary_size()
            self._tokenizer_bundle = TokenizerBundle(
                tokenizer=tokenizer,
                decode=decode,
                end_of_text=end_of_text,
                vocab_size=vocab_size,
            )

    def doc_ds(self) -> tf.data.Dataset:
        """Return the raw document dataset.

        Returns:
            A TensorFlow Dataset whose elements are document strings.
        """
        return doc_load(self._data_path, glob_pattern=self.glob_pattern)

    def tokens_ds(self, seq_length: int, batch_size: int) -> tf.data.Dataset:
        """Return the tokenized dataset.

        Args:
            seq_length: Sequence length of each training example.
            batch_size: Number of sequences per batch.

        Returns:
            A TensorFlow Dataset whose elements are (input_ids, target_ids) pairs.
        """
        self._load_tokenizer()
        ds = self.doc_ds()
        return transform(
            ds=ds,
            tokenizer=self._tokenizer_bundle.tokenizer,
            end_of_text=self._tokenizer_bundle.end_of_text,
            sequence_length=seq_length,
            batch_size=batch_size,
        )

    def tokenizer_bundle(self) -> TokenizerBundle:
        """Return the lazily built tokenizer bundle."""
        self._load_tokenizer()
        return self._tokenizer_bundle
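

# A minimal usage sketch, assuming the mini_c4 corpus from the docstring above
# exists locally and that DataBundle supplies the `data_dir` field and `stat()`.
if __name__ == "__main__":
    dataset = WikiDataset(
        data_dir="~/data/wiki/mini_c4",  # path taken from the class docstring
        tokenizer_type="sentence_piece",
    )
    # Peek at the first raw document.
    for doc in dataset.doc_ds().take(1):
        print(doc.numpy()[:80])
    # Peek at one (input_ids, target_ids) batch and report its shapes.
    for input_ids, target_ids in dataset.tokens_ds(seq_length=256, batch_size=32).take(1):
        print(input_ids.shape, target_ids.shape)
    print("vocab size:", dataset.tokenizer_bundle().vocab_size)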