general-deep-learning / test /data /vocab_test.py
yetrun's picture
ver1: 实现深度学习训练框架,支持 Wiki GPT 与诗歌生成双任务
a5fd608
import pathlib
import tempfile
from keras import layers
from data import PoetryDataset
from data.common import build_vocab_from_dataset
from data.poetry.loader import doc_load
from data.poetry.tokenizer import load_vectorizer
from env.resolve import resolve_path
def test_create_and_load_vectorizer():
    """Build a vocabulary file, reload it as a vectorizer, and round-trip a sample.

    The vocabulary is built from the dev poetry data into a temporary
    directory, loaded back through ``load_vectorizer``, and then used to
    encode the first document of the dataset. Mapping the non-zero token
    ids back through the vocabulary must reproduce the leading characters
    of that document.
    """
    corpus_dir = resolve_path("data/dev/poetry")
    max_tokens_per_row = 100
    corpus = doc_load(corpus_dir)

    with tempfile.TemporaryDirectory() as scratch:
        vocab_file = pathlib.Path(scratch) / "poetry_vocab.txt"
        built_vocab = build_vocab_from_dataset(corpus, vocab_file)

        assert vocab_file.exists(), "词汇表文件应该被创建"
        assert len(built_vocab) > 0, "词汇表不应为空"

        vectorizer = load_vectorizer(vocab_file, max_tokens_per_row)
        reloaded_vocab = vectorizer.get_vocabulary()
        assert len(reloaded_vocab) == len(built_vocab), "加载的词汇表大小应一致"

        # Round-trip check: encode the first document, then map ids back
        # to vocabulary entries.
        lookup = vectorizer.get_vocabulary()
        first_doc = corpus.take(1).get_single_element().numpy().decode("utf-8")
        id_row = vectorizer([first_doc])[0].numpy()
        # Only ids greater than zero are decoded; zero entries are skipped.
        kept_ids = id_row[id_row > 0]
        tokens = [lookup[token_id] for token_id in kept_ids]

        # Compare against as many leading characters of the source text as
        # there are decoded tokens.
        expected_chars = list(first_doc[: len(tokens)])
        actual_chars = list("".join(tokens))
        assert actual_chars == expected_chars, "解码的字符应与原始文本一致"
def test_poetry_dataset_tokenizer_bundle_contains_vocab_path():
    """Check that ``PoetryDataset.tokenizer_bundle()`` reports its vocab path.

    A vocabulary is built at a temporary path from the dev poetry data, a
    ``PoetryDataset`` is constructed with that path, and the bundle's
    ``vocab_path`` attribute must equal the string form of the path.
    """
    corpus_dir = resolve_path("data/dev/poetry")

    with tempfile.TemporaryDirectory() as scratch:
        vocab_file = pathlib.Path(scratch) / "poetry_vocab.txt"
        build_vocab_from_dataset(doc_load(corpus_dir), vocab_file)

        expected_path = str(vocab_file)
        bundle = PoetryDataset(
            data_dir=str(corpus_dir),
            vocab_path=expected_path,
            sequence_length=100,
        ).tokenizer_bundle()

        assert bundle.vocab_path == expected_path