general-deep-learning / data /wiki /transformer.py
yetrun's picture
ver1: 实现深度学习训练框架,支持 Wiki GPT 与诗歌生成双任务
a5fd608
"""Wiki 数据集 token 转换模块
将文档数据集转换为训练用的 token 序列。
"""
from typing import Callable
import numpy as np
import tensorflow as tf
def transform(
ds: tf.data.Dataset,
tokenizer: Callable,
end_of_text: int,
sequence_length: int,
batch_size: int,
) -> tf.data.Dataset:
"""转换文档数据集为训练数据集
将文档转换为 token ID,添加结束标记,分割为固定长度的序列。
Args:
ds: 文档数据集
tokenizer: 分词器函数
end_of_text: 结束标记的 token ID
sequence_length: 序列长度
batch_size: 批次大小
Returns:
训练数据集,每个元素是 (input_ids, target_ids) 对
"""
ds = ds.map(tokenizer, num_parallel_calls=8)
# 将文档之间添加 end_of_text 标记分隔
ds = ds.map(lambda x: tf.concat([x, np.array([end_of_text])], -1))
# 重新设置样本大小为固定长度序列
ds = ds.rebatch(sequence_length + 1, drop_remainder=True)
# 构建输入和目标(偏移一位)
ds = ds.map(lambda x: (x[:-1], x[1:]))
# 重新设置批次大小并预取数据以提高性能
ds = ds.batch(batch_size).prefetch(8)
return ds