import keras
import keras_hub
import pytest
from keras import layers


def test_vectorizer_specified_vocabulary_one():
    """Custom "<pad>"/"<unk>" entries are treated as ordinary vocab tokens.

    TextVectorization appears to reserve index 0 for padding and index 1
    for OOV on its own, so the two hand-written special tokens at the
    front of this vocabulary are not recognized as special — every real
    token's id is shifted up by two ("白" encodes to 4, not 2).
    """
    layer = layers.TextVectorization(
        standardize=None,
        split="character",
        output_mode="int",
        output_sequence_length=10,
    )
    layer.set_vocabulary(["<pad>", "<unk>", "白", "日", "依", "山", "尽", "$"])

    ids = layer(["白日依山尽"]).numpy()[0]
    # Ids start at 4: two reserved slots + the two unrecognized specials.
    expected = [4, 5, 6, 7, 8, 0, 0, 0, 0, 0]
    assert (ids == expected).all(), (
        "编码结果比词表进了2位,因为前面的两个特殊标记没有被认可"
    )


def test_vectorizer_specified_vocabulary_two():
    """Using the layer's own reserved tokens keeps ids aligned with the vocab.

    When the supplied vocabulary leads with "" (padding) and "[UNK]" (OOV)
    — the exact tokens TextVectorization reserves — the remaining tokens
    map straight to their list positions ("白" encodes to 2).
    """
    layer = layers.TextVectorization(
        standardize=None,
        split="character",
        output_mode="int",
        output_sequence_length=10,
    )
    layer.set_vocabulary(["", "[UNK]", "白", "日", "依", "山", "尽", "$"])

    ids = layer(["白日依山尽"]).numpy()[0]
    # Token ids match the vocabulary indices directly.
    expected = [2, 3, 4, 5, 6, 0, 0, 0, 0, 0]
    assert (ids == expected).all(), (
        "编码结果与词表的序号一致"
    )


def test_batch_encode_decode():
    """Round-trip a batch of strings through a SentencePiece tokenizer.

    Downloads a small SentencePiece vocabulary proto, tokenizes a batch
    (including an empty string), checks the exact token ids, then
    detokenizes and verifies the byte strings round-trip unchanged.
    Skipped when tensorflow_text is unavailable.
    """
    pytest.importorskip("tensorflow_text")

    proto_path = keras.utils.get_file(
        origin="https://hf-mirror.com/mattdangerw/spiece/resolve/main/vocabulary.proto"
    )
    tokenizer = keras_hub.tokenizers.SentencePieceTokenizer(proto_path)

    # Batch encode. Without a sequence_length argument the tokenizer
    # returns a RaggedTensor; passing sequence_length would instead give
    # a dense Tensor padded with 0s.
    tokens = tokenizer.tokenize(["", "Hi!", "Machine learning is amazing."])
    expected_tokens = [[], [6324, 29991], [6189, 6509, 338, 21863, 292, 29889]]
    assert tokens.to_list() == expected_tokens, f"编码结果不匹配: {tokens.to_list()}"

    # Batch decode and compare the raw byte strings.
    decoded = tokenizer.detokenize(tokens)
    expected_decoded = [b"", b"Hi!", b"Machine learning is amazing."]
    assert decoded.numpy().tolist() == expected_decoded, (
        f"解码结果不匹配: {decoded.numpy().tolist()}"
    )