File size: 1,247 Bytes
a5fd608
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
"""Wiki 数据集 token 转换模块

将文档数据集转换为训练用的 token 序列。
"""

from typing import Callable

import numpy as np
import tensorflow as tf


def transform(
    ds: tf.data.Dataset,
    tokenizer: Callable,
    end_of_text: int,
    sequence_length: int,
    batch_size: int,
) -> tf.data.Dataset:
    """转换文档数据集为训练数据集

    将文档转换为 token ID,添加结束标记,分割为固定长度的序列。

    Args:
        ds: 文档数据集
        tokenizer: 分词器函数
        end_of_text: 结束标记的 token ID
        sequence_length: 序列长度
        batch_size: 批次大小

    Returns:
        训练数据集,每个元素是 (input_ids, target_ids) 对
    """
    ds = ds.map(tokenizer, num_parallel_calls=8)

    # 将文档之间添加 end_of_text 标记分隔
    ds = ds.map(lambda x: tf.concat([x, np.array([end_of_text])], -1))

    # 重新设置样本大小为固定长度序列
    ds = ds.rebatch(sequence_length + 1, drop_remainder=True)

    # 构建输入和目标(偏移一位)
    ds = ds.map(lambda x: (x[:-1], x[1:]))

    # 重新设置批次大小并预取数据以提高性能
    ds = ds.batch(batch_size).prefetch(8)

    return ds