Spaces:
Running
Running
| import pathlib | |
| import tempfile | |
| from keras import layers | |
| from data import PoetryDataset | |
| from data.common import build_vocab_from_dataset | |
| from data.poetry.loader import doc_load | |
| from data.poetry.tokenizer import load_vectorizer | |
| from env.resolve import resolve_path | |
| def test_create_and_load_vectorizer(): | |
| """测试加载已保存的 TextVectorization 层并验证编码解码""" | |
| data_dir = resolve_path("data/dev/poetry") | |
| sequence_length = 100 | |
| dataset = doc_load(data_dir) | |
| def _check_vectorizer_encode(vectorizer: layers.TextVectorization): | |
| vocab = vectorizer.get_vocabulary() | |
| sample_tensor = dataset.take(1).get_single_element() | |
| sample_text = sample_tensor.numpy().decode("utf-8") | |
| encoded = vectorizer([sample_text]) | |
| nonzero_indices = encoded[0].numpy()[encoded[0].numpy() > 0] | |
| decoded = [vocab[idx] for idx in nonzero_indices] | |
| decoded_text = "".join(decoded) | |
| original_chars = list(sample_text[: len(decoded)]) | |
| decoded_chars = list(decoded_text) | |
| assert decoded_chars == original_chars, "解码的字符应与原始文本一致" | |
| with tempfile.TemporaryDirectory() as tmpdir: | |
| vocab_path = pathlib.Path(tmpdir) / "poetry_vocab.txt" | |
| vocab = build_vocab_from_dataset(dataset, vocab_path) | |
| assert vocab_path.exists(), "词汇表文件应该被创建" | |
| assert len(vocab) > 0, "词汇表不应为空" | |
| loaded_vectorizer = load_vectorizer(vocab_path, sequence_length) | |
| loaded_vocab = loaded_vectorizer.get_vocabulary() | |
| assert len(loaded_vocab) == len(vocab), "加载的词汇表大小应一致" | |
| _check_vectorizer_encode(loaded_vectorizer) | |
| def test_poetry_dataset_tokenizer_bundle_contains_vocab_path(): | |
| data_dir = resolve_path("data/dev/poetry") | |
| with tempfile.TemporaryDirectory() as tmpdir: | |
| vocab_path = pathlib.Path(tmpdir) / "poetry_vocab.txt" | |
| dataset = doc_load(data_dir) | |
| build_vocab_from_dataset(dataset, vocab_path) | |
| poetry_dataset = PoetryDataset( | |
| data_dir=str(data_dir), | |
| vocab_path=str(vocab_path), | |
| sequence_length=100 | |
| ) | |
| tokenizer_info = poetry_dataset.tokenizer_bundle() | |
| assert tokenizer_info.vocab_path == str(vocab_path) | |