yetrun's picture
ver1: 实现深度学习训练框架,支持 Wiki GPT 与诗歌生成双任务
a5fd608
"""Wiki 数据集文档加载模块
从 mini_c4 格式加载文档数据集。
"""
import pathlib
import tensorflow as tf
def doc_load(
data_dir: pathlib.Path, glob_pattern: str = "*", cycle_length: int = 32
) -> tf.data.Dataset:
"""加载并处理文档数据集为 TensorFlow Dataset。
递归查找指定目录下匹配 glob_pattern 的所有文件,使用 doc_extract 函数
将每个文件转换为 TensorFlow Dataset,然后使用 interleave 进行并行处理。
目录下的文件格式要求每行一个文档,其中的换行符使用 "\\n" 转义。
Args:
data_dir: 数据目录路径
glob_pattern: 文件匹配模式,如 "*.txt",默认为 "*" 匹配所有文件
cycle_length: interleave 的 cycle_length 参数,控制并行处理的文件数量,默认为 32
Returns:
合并后的 TensorFlow Dataset,包含所有文件处理后的数据
"""
# 获取所有文件(过滤掉目录),递归查找子目录
files = [str(file) for file in data_dir.rglob(glob_pattern) if file.is_file()]
if not files:
raise FileNotFoundError(f"在目录 {data_dir} 中未找到匹配 {glob_pattern} 的文件")
# 排序文件列表以确保一致的处理顺序
files = sorted(files)
# 创建数据集管道
ds = tf.data.Dataset.from_tensor_slices(files)
ds = ds.interleave(
_line_doc_extract,
cycle_length=cycle_length,
num_parallel_calls=tf.data.AUTOTUNE,
)
return ds
def _line_doc_extract(path: str) -> tf.data.Dataset:
"""Mini-c4 format: one document per line."""
return tf.data.TextLineDataset(path).map(
lambda x: tf.strings.regex_replace(x, r"\\n", "\n")
)