File size: 4,364 Bytes
a5fd608
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
"""数据集共享工具模块

提供数据集统计、报告生成等共享功能。
"""

import pathlib
from dataclasses import dataclass
from typing import Callable

import numpy as np
import tensorflow as tf
from keras import layers


@dataclass
class DatasetStats:
    """数据集统计结果"""

    name: str
    doc_count: int
    total_chars: int
    total_tokens: int
    max_length: int
    median_length: int

    def print_report(self, seq_length: int | None = 256):
        """打印统一格式的统计报表

        Args:
            seq_length: 序列长度,用于估算训练样本数。
                       为 None 时表示不切割,一个文档一个样本。
        """
        avg_chars = self.total_chars / self.doc_count if self.doc_count > 0 else 0
        avg_tokens = self.total_tokens / self.doc_count if self.doc_count > 0 else 0

        print()
        print("=" * 60)
        print(f"{self.name} 数据集统计")
        print("=" * 60)
        print(f"{'文档数:':<20} {self.doc_count:>15,}")
        print(f"{'总字符数:':<20} {self.total_chars:>15,}")
        print(f"{'总 Token 数:':<20} {self.total_tokens:>15,}")
        print("-" * 60)
        print(f"{'平均每文档字符数:':<20} {avg_chars:>15.1f}")
        print(f"{'平均每文档 Token 数:':<20} {avg_tokens:>15.1f}")
        print(f"{'最长文档字符数:':<20} {self.max_length:>15,}")
        print(f"{'文档长度中位数:':<20} {self.median_length:>15,}")
        print("=" * 60)

        if self.total_tokens > 0:
            print()
            if seq_length is None:
                print(f"训练样本数: {self.doc_count:,} 个 (一个文档一个样本)")
            else:
                print(f"训练样本预估 (seq={seq_length}):")
                print(f"  可生成约 {self.total_tokens // seq_length:,} 个训练样本")


def collect_stats(
    name: str, loader: Callable[[], tf.data.Dataset], tokenizer: Callable
) -> DatasetStats:
    """从 DatasetLoader 收集统计数据

    Args:
        name: 数据集名称(用于报表显示)
        loader: 返回 tf.data.Dataset 的加载器函数
        tokenizer: 分词器函数,接收文本返回 token ID 列表

    Returns:
        DatasetStats 统计结果对象
    """
    ds = loader()

    doc_count = 0
    total_chars = 0
    total_tokens = 0
    lengths = []

    for item in ds:
        text = item.numpy().decode("utf-8")
        if not text.strip():
            continue

        doc_count += 1
        total_chars += len(text)
        lengths.append(len(text))

        # Token 统计,过滤掉末尾的 padding (值为 0 的 token)
        try:
            import keras

            token_ids = keras.ops.convert_to_numpy(tokenizer(text))
        except ImportError:
            # Fallback: assume tokenizer returns numpy array directly
            token_ids = np.array(tokenizer(text))

        # 只去掉末尾的 0,保留中间内容(包括中间的 OOV/padding)
        valid_tokens = np.trim_zeros(token_ids, "b")
        total_tokens += len(valid_tokens)

    return DatasetStats(
        name=name,
        doc_count=doc_count,
        total_chars=total_chars,
        total_tokens=total_tokens,
        max_length=max(lengths) if lengths else 0,
        median_length=int(np.median(lengths)) if lengths else 0,
    )


def save_vocabulary(vocab: list[str], vocab_path: pathlib.Path) -> None:
    """保存词汇表到文件

    Args:
        vocab: 词汇表列表
        vocab_path: 保存路径
    """
    vocab_path.parent.mkdir(parents=True, exist_ok=True)
    with open(vocab_path, "w", encoding="utf-8") as f:
        for char in vocab:
            written = char if char != "\n" else r"\n"
            f.write(written + "\n")


def build_vocab_from_dataset(
    doc_ds: tf.data.Dataset, vocab_path: pathlib.Path
) -> list[str]:
    """从文档数据集构建词汇表

    Args:
        doc_ds: 文档数据集
        vocab_path: 词汇表保存路径

    Returns:
        词汇表列表
    """
    vectorizer = layers.TextVectorization(
        output_mode="int", split="character", standardize=None
    )
    vectorizer.adapt(doc_ds, batch_size=128)

    vocab = vectorizer.get_vocabulary()
    if "$" not in vocab:
        vocab = [*vocab, "$"]

    save_vocabulary(vocab, vocab_path)
    return vocab