Spaces:
Sleeping
Sleeping
File size: 1,740 Bytes
a5fd608 | 1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 | """Wiki 数据集文档加载模块
从 mini_c4 格式加载文档数据集。
"""
import pathlib
import tensorflow as tf
def doc_load(
data_dir: pathlib.Path, glob_pattern: str = "*", cycle_length: int = 32
) -> tf.data.Dataset:
"""加载并处理文档数据集为 TensorFlow Dataset。
递归查找指定目录下匹配 glob_pattern 的所有文件,使用 doc_extract 函数
将每个文件转换为 TensorFlow Dataset,然后使用 interleave 进行并行处理。
目录下的文件格式要求每行一个文档,其中的换行符使用 "\\n" 转义。
Args:
data_dir: 数据目录路径
glob_pattern: 文件匹配模式,如 "*.txt",默认为 "*" 匹配所有文件
cycle_length: interleave 的 cycle_length 参数,控制并行处理的文件数量,默认为 32
Returns:
合并后的 TensorFlow Dataset,包含所有文件处理后的数据
"""
# 获取所有文件(过滤掉目录),递归查找子目录
files = [str(file) for file in data_dir.rglob(glob_pattern) if file.is_file()]
if not files:
raise FileNotFoundError(f"在目录 {data_dir} 中未找到匹配 {glob_pattern} 的文件")
# 排序文件列表以确保一致的处理顺序
files = sorted(files)
# 创建数据集管道
ds = tf.data.Dataset.from_tensor_slices(files)
ds = ds.interleave(
_line_doc_extract,
cycle_length=cycle_length,
num_parallel_calls=tf.data.AUTOTUNE,
)
return ds
def _line_doc_extract(path: str) -> tf.data.Dataset:
"""Mini-c4 format: one document per line."""
return tf.data.TextLineDataset(path).map(
lambda x: tf.strings.regex_replace(x, r"\\n", "\n")
)
|