"""数据集 Runner 公共模块

提供通用的数据集测试和词汇表生成功能。

Usage:
    # 在各自 runner.py 中实例化
    from data.runner import DatasetRunner
    from data.poetry.dataset import PoetryDataset
    from env.resolve import resolve, resolve_saved, resolve_env

    dataset = PoetryDataset(
        data_dir=str(resolve_env(resolve("data/dev/poetry"), resolve("~/data/Poetry/诗歌数据集"))),
        vocab_path=str(resolve_env(resolve_saved("poetry/vocab.txt"), resolve("~/data/Poetry/vocabulary.txt"))),
        sequence_length=100,
    )
    runner = DatasetRunner(dataset=dataset, name="poetry")
    runner()
"""

from data.base import DataBundle
from data.common import build_vocab_from_dataset
from env.resolve import resolve_saved
from env.runner import ActionRunner


class DatasetRunner(ActionRunner):
    """数据集 Runner

    提供通用的数据集测试和词汇表生成功能。

    Args:
        dataset: 数据集实例(PoetryDataset 或 WikiDataset)
        name: 数据集英文名称(如 "poetry", "wiki")
        max_docs: 测试时显示的文档数量,默认 5
        max_samples: 测试时显示的 token 样本数量,默认 3
        max_doc_chars: 文档显示的最大字符数,默认 200
        max_text_display: token 文本显示的最大字符数,默认 80

    Usage:
        runner = DatasetRunner(dataset=poetry_dataset, name="poetry")
        runner.test_dataset()  # 或 runner.build_vocab()
    """

    # Mapping from dataset names to human-readable display names
    NAME_MAP = {
        "poetry": "Poetry",
        "wiki": "Wiki",
    }

    def __init__(
        self,
        dataset: DataBundle,
        name: str,
        max_docs: int = 5,
        max_samples: int = 3,
        max_doc_chars: int = 200,
        max_text_display: int = 80,
    ):
        self.dataset = dataset
        self.name = name
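        # Fall back to the raw dataset name when no display mapping exists.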
        self.display_name = self.NAME_MAP.get(name, name)
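        # Per-dataset vocabulary artifact, vocab/<name>/vocab.txt, at the
        # location resolved by resolve_saved.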
        self.vocab_path = resolve_saved(f"vocab/{name}/vocab.txt")
        self.max_docs = max_docs
        self.max_samples = max_samples
        self.max_doc_chars = max_doc_chars
        self.max_text_display = max_text_display

    def build_vocab(self) -> None:
        """Build the character vocabulary."""
        print("Loading dataset...")
        ds = self.dataset.doc_ds()

        print(f"Saving vocabulary to: {self.vocab_path}")
        vocab = build_vocab_from_dataset(ds, self.vocab_path)

        print(f"Vocabulary size: {len(vocab)}")
        print("Done!")

    def test_dataset(self) -> None:
        """Test the dataset."""
        print("\n" + "=" * 60)
        print(f"{self.display_name} dataset test")
        print("=" * 60)

        self._view_documents(self.dataset.doc_ds())
        self._view_tokens(self.dataset)
        self._show_vocab_info(self.dataset.tokenizer_bundle())

        print("\n" + "=" * 60)
        print("Test complete")
        print("=" * 60)

    def _view_documents(self, doc_ds) -> None:
        """Preview raw documents."""
        print("\n[Raw document preview]")
        print("-" * 60)
        count = 0
        for doc in doc_ds.take(self.max_docs):
            count += 1
            # Each element is a scalar string tensor; decode it to a Python str.
            text = doc.numpy().decode("utf-8")
            if len(text) > self.max_doc_chars:
                text = text[: self.max_doc_chars] + "..."
            print(f"\nDocument {count}:")
            print(f"  {text}")
        print(f"\nDisplayed {count} documents")

    def _view_tokens(self, dataset) -> None:
        """Preview tokenized data."""
        print("\n[Tokenized data preview]")
        print("-" * 60)

        tokenizer_info = dataset.tokenizer_bundle()
        tokens_ds = dataset.tokens_ds(seq_length=dataset.sequence_length, batch_size=1)

        count = 0
        for batch_input, batch_target in tokens_ds.take(self.max_samples):
            count += 1
            # batch_size=1, so index out the single example in each batch.
            input_ids = batch_input[0].numpy()
            target_ids = batch_target[0].numpy()

            input_text = tokenizer_info.decode(input_ids.tolist())
            target_text = tokenizer_info.decode(target_ids.tolist())

            if len(input_text) > self.max_text_display:
                input_text = input_text[: self.max_text_display] + "..."
            if len(target_text) > self.max_text_display:
                target_text = target_text[: self.max_text_display] + "..."

            print(f"\nSample {count}:")
            print(f"  Input tokens: {input_ids[:20]}... (length: {len(input_ids)})")
            print(f"  Target tokens: {target_ids[:20]}... (length: {len(target_ids)})")
            print(f"  Input text: {input_text}")
            print(f"  Target text: {target_text}")
        print(f"\nDisplayed {count} samples")

    @staticmethod
    def _show_vocab_info(tokenizer_info) -> None:
        """显示词汇表信息"""
        print("\n【词汇表信息】")
        print("-" * 60)
        print(f"  词汇表大小: {tokenizer_info.vocab_size}")
        print(f"  结束标记 ID: {tokenizer_info.end_of_text}")