Spaces:
Sleeping
Sleeping
| """诗歌数据集文档加载模块 | |
| 从 CSV 文件加载诗歌文本数据。 | |
| """ | |
| import glob | |
| import os | |
| import pathlib | |
| import tensorflow as tf | |
| def _parse_csv_line(line: tf.Tensor) -> tf.Tensor: | |
| """解析 CSV 行,返回内容列""" | |
| fields = tf.io.decode_csv( | |
| line, | |
| use_quote_delim=False, # 行内的引号是普通字符 | |
| record_defaults=["", "", "", "", ""], | |
| ) | |
| return fields[4] # 返回 '内容' 列的值 | |
| def doc_load(data_dir: pathlib.Path) -> tf.data.Dataset: | |
| """加载诗歌数据集 | |
| 从指定目录下的 CSV 文件中加载诗歌文本数据。 | |
| 每个 CSV 文件应该包含以下列:标题、作者、朝代、类型、内容。 | |
| Args: | |
| data_dir: 数据目录路径 | |
| Returns: | |
| TensorFlow Dataset,每个元素是诗歌内容字符串 | |
| """ | |
| csv_files = glob.glob(os.path.join(data_dir, "*.csv")) | |
| if not csv_files: | |
| raise ValueError(f"在目录 {data_dir} 中未找到任何 CSV 文件!") | |
| files_ds = tf.data.Dataset.from_tensor_slices(csv_files) | |
| csv_line_ds = files_ds.interleave( | |
| lambda csv_file: tf.data.TextLineDataset(csv_file).skip(1), | |
| cycle_length=1, | |
| ) | |
| return csv_line_ds.map(_parse_csv_line, num_parallel_calls=tf.data.AUTOTUNE).filter( | |
| lambda x: tf.strings.length(x) > 0 | |
| ) | |
| def doc_load_with_eot(data_dir: pathlib.Path) -> tf.data.Dataset: | |
| """加载诗歌数据集,每行末尾添加结束标记 | |
| Args: | |
| data_dir: 数据目录路径 | |
| Returns: | |
| TensorFlow Dataset,每个元素是带结束标记的诗歌内容 | |
| """ | |
| ds = doc_load(data_dir) | |
| return ds.map(lambda x: tf.strings.join([x, "$"])) | |