Spaces:
Sleeping
Sleeping
| """Wiki 数据集 token 转换模块 | |
| 将文档数据集转换为训练用的 token 序列。 | |
| """ | |
| from typing import Callable | |
| import numpy as np | |
| import tensorflow as tf | |
| def transform( | |
| ds: tf.data.Dataset, | |
| tokenizer: Callable, | |
| end_of_text: int, | |
| sequence_length: int, | |
| batch_size: int, | |
| ) -> tf.data.Dataset: | |
| """转换文档数据集为训练数据集 | |
| 将文档转换为 token ID,添加结束标记,分割为固定长度的序列。 | |
| Args: | |
| ds: 文档数据集 | |
| tokenizer: 分词器函数 | |
| end_of_text: 结束标记的 token ID | |
| sequence_length: 序列长度 | |
| batch_size: 批次大小 | |
| Returns: | |
| 训练数据集,每个元素是 (input_ids, target_ids) 对 | |
| """ | |
| ds = ds.map(tokenizer, num_parallel_calls=8) | |
| # 将文档之间添加 end_of_text 标记分隔 | |
| ds = ds.map(lambda x: tf.concat([x, np.array([end_of_text])], -1)) | |
| # 重新设置样本大小为固定长度序列 | |
| ds = ds.rebatch(sequence_length + 1, drop_remainder=True) | |
| # 构建输入和目标(偏移一位) | |
| ds = ds.map(lambda x: (x[:-1], x[1:])) | |
| # 重新设置批次大小并预取数据以提高性能 | |
| ds = ds.batch(batch_size).prefetch(8) | |
| return ds | |