Spaces:
Sleeping
Sleeping
File size: 1,247 Bytes
a5fd608 | 1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 | """Wiki 数据集 token 转换模块
将文档数据集转换为训练用的 token 序列。
"""
from typing import Callable
import numpy as np
import tensorflow as tf
def transform(
ds: tf.data.Dataset,
tokenizer: Callable,
end_of_text: int,
sequence_length: int,
batch_size: int,
) -> tf.data.Dataset:
"""转换文档数据集为训练数据集
将文档转换为 token ID,添加结束标记,分割为固定长度的序列。
Args:
ds: 文档数据集
tokenizer: 分词器函数
end_of_text: 结束标记的 token ID
sequence_length: 序列长度
batch_size: 批次大小
Returns:
训练数据集,每个元素是 (input_ids, target_ids) 对
"""
ds = ds.map(tokenizer, num_parallel_calls=8)
# 将文档之间添加 end_of_text 标记分隔
ds = ds.map(lambda x: tf.concat([x, np.array([end_of_text])], -1))
# 重新设置样本大小为固定长度序列
ds = ds.rebatch(sequence_length + 1, drop_remainder=True)
# 构建输入和目标(偏移一位)
ds = ds.map(lambda x: (x[:-1], x[1:]))
# 重新设置批次大小并预取数据以提高性能
ds = ds.batch(batch_size).prefetch(8)
return ds
|