File size: 2,235 Bytes
a5fd608
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
"""数据集抽象基类模块

定义 DataBundle 抽象基类,统一数据集的接口规范。
每个具体的数据集(如 Wiki、诗歌)都应该继承此类并实现相应方法。
"""

from abc import ABC, abstractmethod
from dataclasses import dataclass, field
from typing import Callable, Optional

import tensorflow as tf


@dataclass
class TokenizerBundle:
    """分词器信息包装类

    将分词器相关的属性打包在一起,简化 DataBundle 接口。
    """

    tokenizer: Callable
    decode: Callable
    end_of_text: int
    vocab_size: int
    vocab_path: str = ""


@dataclass
class DataBundle(ABC):
    """数据集抽象基类

    将数据加载、分词、统计等功能绑定在一起,提供统一的数据集接口。

    Usage:
        dataset = WikiDataset(data_dir="~/data/wiki")
        doc_ds = dataset.doc_ds()
        tokens_ds = dataset.tokens_ds(seq_length=256, batch_size=32)
        dataset.stat()
    """

    data_dir: str
    sequence_length: int = 256

    @abstractmethod
    def doc_ds(self) -> tf.data.Dataset:
        """返回原始文档数据集

        Returns:
            TensorFlow Dataset,每个元素是一个文档字符串
        """
        pass

    @abstractmethod
    def tokens_ds(self, seq_length: int, batch_size: int) -> tf.data.Dataset:
        """返回 tokenized 数据集

        将原始文档转换为 token ID 序列,并分割为训练样本。

        Args:
            seq_length: 序列长度
            batch_size: 批次大小

        Returns:
            TensorFlow Dataset,每个元素是 (input_ids, target_ids) 对
        """
        pass

    @abstractmethod
    def tokenizer_bundle(self) -> TokenizerBundle:
        """返回分词器信息"""
        pass

    def stat(self, seq_length: int | None = None) -> None:
        """打印数据集统计信息

        Args:
            seq_length: 序列长度,用于估算训练样本数
        """
        from data.common import collect_stats

        info = self.tokenizer_bundle()
        stats = collect_stats(
            name=self.__class__.__name__, loader=self.doc_ds, tokenizer=info.tokenizer
        )
        stats.print_report(seq_length=seq_length)