File size: 4,625 Bytes
a5fd608
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
import pytest
import tensorflow as tf
import numpy as np

from models.mini_gpt import GptModelBuilder
from models.rnn import RNNModelBuilder
from pipeline.base.model_builder import GenerationContext


def _sample_one(logits):
    """Deterministic stand-in ``sample_fn`` that always picks token id ``1``.

    ``logits`` is accepted only to satisfy the sampler call signature and is
    deliberately ignored, so generated sequences never depend on the model's
    weights.
    """
    del logits  # unused by design; see docstring
    return tf.constant([1], dtype="int32")


@pytest.mark.parametrize(
    "builder",
    [
        # Explicit ids make failures read as "[gpt]" / "[rnn]" instead of the
        # opaque auto-generated "builder0" / "builder1".
        pytest.param(
            GptModelBuilder(
                hidden_dim=8,
                intermediate_dim=16,
                num_heads=2,
                num_layers=1
            ),
            id="gpt"
        ),
        pytest.param(
            RNNModelBuilder(
                num_layers=1,
                embedding_dim=8,
                hidden_dim=16
            ),
            id="rnn"
        )
    ]
)
def test_builder_training_and_inference_generate_match(builder):
    """Training and inference artifacts must produce identical generations.

    ``_sample_one`` always returns token ``1`` and ``end_of_text`` is 99, so
    the end-of-text token is never sampled and generation is expected to run
    until ``max_length`` (6) tokens, prompt included. The expected output is
    therefore the prompt ``[2, 3, 4]`` followed by three ``1`` tokens.
    """
    training_artifact = builder.build_training_artifact(
        vocab_size=32,
        sequence_length=16
    )
    inference_artifact = builder.build_inference_artifact(
        training_artifact=training_artifact
    )
    context = GenerationContext(
        end_of_text=99,
        max_length=6,
        sample_fn=_sample_one
    )

    # Each call gets its own list so neither run can observe the other's
    # prompt if generate() happens to mutate its input.
    training_result = training_artifact.generate(context, [2, 3, 4])
    inference_result = inference_artifact.generate(context, [2, 3, 4])

    assert training_result.token_ids == [2, 3, 4, 1, 1, 1]
    assert inference_result.token_ids == training_result.token_ids
    assert inference_result.stop_reason == training_result.stop_reason


def test_gpt_inference_artifact_reuses_training_artifact():
    """The GPT builder does not build a separate inference model.

    ``build_inference_artifact`` is expected to hand back the very artifact it
    was given, and therefore the very same model object.
    """
    gpt_builder = GptModelBuilder(
        hidden_dim=8, intermediate_dim=16, num_heads=2, num_layers=1
    )
    trained = gpt_builder.build_training_artifact(
        vocab_size=32, sequence_length=16
    )
    served = gpt_builder.build_inference_artifact(training_artifact=trained)

    # Identity (not equality) checks: no copy of the artifact or model allowed.
    assert served is trained
    assert served.model is trained.model
    assert inference_artifact.model is training_artifact.model


def test_rnn_inference_artifact_uses_distinct_model():
    """The RNN builder builds a separate inference artifact and model.

    Unlike the GPT builder, neither the returned artifact nor its model may be
    the same object as the training ones.
    """
    rnn_builder = RNNModelBuilder(num_layers=1, embedding_dim=8, hidden_dim=16)
    trained = rnn_builder.build_training_artifact(
        vocab_size=32, sequence_length=16
    )
    served = rnn_builder.build_inference_artifact(training_artifact=trained)

    assert served is not trained
    assert served.model is not trained.model


def test_rnn_inference_model_outputs_logits_and_states():
    """The RNN inference model returns logits followed by per-layer states.

    The model is called with a token batch plus two zero-filled state tensors
    per layer. It must return one logits tensor of shape
    ``(batch, vocab_size)`` (presumably last-step logits; confirm against
    the builder), followed by as many state tensors as were passed in, each
    shaped ``(batch, hidden_dim)``.
    """
    # Single source of truth for sizes used both to build and to assert.
    vocab_size = 32
    batch_size = 1
    builder = RNNModelBuilder(
        num_layers=2,
        embedding_dim=8,
        hidden_dim=16
    )
    training_artifact = builder.build_training_artifact(
        vocab_size=vocab_size,
        sequence_length=16
    )
    inference_artifact = builder.build_inference_artifact(
        training_artifact=training_artifact
    )

    # Two state tensors per layer (presumably LSTM h and c; verify in builder).
    num_states = builder.num_layers * 2
    token_input = tf.constant([[2, 3, 4]], dtype="int32")
    state_inputs = [
        tf.zeros((batch_size, builder.hidden_dim)) for _ in range(num_states)
    ]

    outputs = inference_artifact.model([token_input] + state_inputs, training=False)

    assert len(outputs) == 1 + num_states
    assert outputs[0].shape == (batch_size, vocab_size)
    for state in outputs[1:]:
        assert state.shape == (batch_size, builder.hidden_dim)


def test_rnn_inference_model_copies_training_weights():
    """Every weight of the RNN inference model equals its training counterpart.

    Compares the full weight list of the ``embedding``, ``logits`` and each
    ``lstm_{i}`` layer by name between the training and inference models.
    """
    builder = RNNModelBuilder(
        num_layers=2,
        embedding_dim=8,
        hidden_dim=16
    )
    training_artifact = builder.build_training_artifact(
        vocab_size=32,
        sequence_length=16
    )

    inference_artifact = builder.build_inference_artifact(
        training_artifact=training_artifact
    )

    training_model = training_artifact.model
    inference_model = inference_artifact.model

    layer_names = ["embedding", "logits"] + [
        f"lstm_{i}" for i in range(builder.num_layers)
    ]
    for name in layer_names:
        training_weights = training_model.get_layer(name).get_weights()
        inference_weights = inference_model.get_layer(name).get_weights()

        # zip() silently stops at the shorter list, so pin the counts first;
        # otherwise a layer with missing (or no) weights would pass vacuously.
        assert training_weights, f"layer {name!r} has no weights"
        assert len(training_weights) == len(inference_weights), (
            f"layer {name!r}: {len(training_weights)} training weight arrays "
            f"vs {len(inference_weights)} inference weight arrays"
        )
        for expected, actual in zip(training_weights, inference_weights):
            np.testing.assert_allclose(expected, actual, err_msg=name)