Spaces:
Sleeping
Sleeping
| """Wiki 数据集文档加载模块 | |
| 从 mini_c4 格式加载文档数据集。 | |
| """ | |
| import pathlib | |
| import tensorflow as tf | |
| def doc_load( | |
| data_dir: pathlib.Path, glob_pattern: str = "*", cycle_length: int = 32 | |
| ) -> tf.data.Dataset: | |
| """加载并处理文档数据集为 TensorFlow Dataset。 | |
| 递归查找指定目录下匹配 glob_pattern 的所有文件,使用 doc_extract 函数 | |
| 将每个文件转换为 TensorFlow Dataset,然后使用 interleave 进行并行处理。 | |
| 目录下的文件格式要求每行一个文档,其中的换行符使用 "\\n" 转义。 | |
| Args: | |
| data_dir: 数据目录路径 | |
| glob_pattern: 文件匹配模式,如 "*.txt",默认为 "*" 匹配所有文件 | |
| cycle_length: interleave 的 cycle_length 参数,控制并行处理的文件数量,默认为 32 | |
| Returns: | |
| 合并后的 TensorFlow Dataset,包含所有文件处理后的数据 | |
| """ | |
| # 获取所有文件(过滤掉目录),递归查找子目录 | |
| files = [str(file) for file in data_dir.rglob(glob_pattern) if file.is_file()] | |
| if not files: | |
| raise FileNotFoundError(f"在目录 {data_dir} 中未找到匹配 {glob_pattern} 的文件") | |
| # 排序文件列表以确保一致的处理顺序 | |
| files = sorted(files) | |
| # 创建数据集管道 | |
| ds = tf.data.Dataset.from_tensor_slices(files) | |
| ds = ds.interleave( | |
| _line_doc_extract, | |
| cycle_length=cycle_length, | |
| num_parallel_calls=tf.data.AUTOTUNE, | |
| ) | |
| return ds | |
| def _line_doc_extract(path: str) -> tf.data.Dataset: | |
| """Mini-c4 format: one document per line.""" | |
| return tf.data.TextLineDataset(path).map( | |
| lambda x: tf.strings.regex_replace(x, r"\\n", "\n") | |
| ) | |