Spaces:
Sleeping
Sleeping
File size: 4,364 Bytes
a5fd608 | 1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 | """数据集共享工具模块
提供数据集统计、报告生成等共享功能。
"""
import pathlib
from dataclasses import dataclass
from typing import Callable
import numpy as np
import tensorflow as tf
from keras import layers
@dataclass
class DatasetStats:
"""数据集统计结果"""
name: str
doc_count: int
total_chars: int
total_tokens: int
max_length: int
median_length: int
def print_report(self, seq_length: int | None = 256):
"""打印统一格式的统计报表
Args:
seq_length: 序列长度,用于估算训练样本数。
为 None 时表示不切割,一个文档一个样本。
"""
avg_chars = self.total_chars / self.doc_count if self.doc_count > 0 else 0
avg_tokens = self.total_tokens / self.doc_count if self.doc_count > 0 else 0
print()
print("=" * 60)
print(f"{self.name} 数据集统计")
print("=" * 60)
print(f"{'文档数:':<20} {self.doc_count:>15,}")
print(f"{'总字符数:':<20} {self.total_chars:>15,}")
print(f"{'总 Token 数:':<20} {self.total_tokens:>15,}")
print("-" * 60)
print(f"{'平均每文档字符数:':<20} {avg_chars:>15.1f}")
print(f"{'平均每文档 Token 数:':<20} {avg_tokens:>15.1f}")
print(f"{'最长文档字符数:':<20} {self.max_length:>15,}")
print(f"{'文档长度中位数:':<20} {self.median_length:>15,}")
print("=" * 60)
if self.total_tokens > 0:
print()
if seq_length is None:
print(f"训练样本数: {self.doc_count:,} 个 (一个文档一个样本)")
else:
print(f"训练样本预估 (seq={seq_length}):")
print(f" 可生成约 {self.total_tokens // seq_length:,} 个训练样本")
def collect_stats(
name: str, loader: Callable[[], tf.data.Dataset], tokenizer: Callable
) -> DatasetStats:
"""从 DatasetLoader 收集统计数据
Args:
name: 数据集名称(用于报表显示)
loader: 返回 tf.data.Dataset 的加载器函数
tokenizer: 分词器函数,接收文本返回 token ID 列表
Returns:
DatasetStats 统计结果对象
"""
ds = loader()
doc_count = 0
total_chars = 0
total_tokens = 0
lengths = []
for item in ds:
text = item.numpy().decode("utf-8")
if not text.strip():
continue
doc_count += 1
total_chars += len(text)
lengths.append(len(text))
# Token 统计,过滤掉末尾的 padding (值为 0 的 token)
try:
import keras
token_ids = keras.ops.convert_to_numpy(tokenizer(text))
except ImportError:
# Fallback: assume tokenizer returns numpy array directly
token_ids = np.array(tokenizer(text))
# 只去掉末尾的 0,保留中间内容(包括中间的 OOV/padding)
valid_tokens = np.trim_zeros(token_ids, "b")
total_tokens += len(valid_tokens)
return DatasetStats(
name=name,
doc_count=doc_count,
total_chars=total_chars,
total_tokens=total_tokens,
max_length=max(lengths) if lengths else 0,
median_length=int(np.median(lengths)) if lengths else 0,
)
def save_vocabulary(vocab: list[str], vocab_path: pathlib.Path) -> None:
"""保存词汇表到文件
Args:
vocab: 词汇表列表
vocab_path: 保存路径
"""
vocab_path.parent.mkdir(parents=True, exist_ok=True)
with open(vocab_path, "w", encoding="utf-8") as f:
for char in vocab:
written = char if char != "\n" else r"\n"
f.write(written + "\n")
def build_vocab_from_dataset(
doc_ds: tf.data.Dataset, vocab_path: pathlib.Path
) -> list[str]:
"""从文档数据集构建词汇表
Args:
doc_ds: 文档数据集
vocab_path: 词汇表保存路径
Returns:
词汇表列表
"""
vectorizer = layers.TextVectorization(
output_mode="int", split="character", standardize=None
)
vectorizer.adapt(doc_ds, batch_size=128)
vocab = vectorizer.get_vocabulary()
if "$" not in vocab:
vocab = [*vocab, "$"]
save_vocabulary(vocab, vocab_path)
return vocab
|