File size: 1,537 Bytes
a5fd608
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
from dataclasses import dataclass
from functools import partial

import keras
from keras import layers

from models.mini_gpt.gpt_components import PositionalEmbedding, TransformerDecoder
from pipeline.base.generation import generate_with_training_model
from pipeline.base.model_builder import ModelArtifact


@dataclass
class GptModelBuilder:
    """Build mini-GPT Keras models from a fixed set of architecture hyperparameters.

    The builder produces a training artifact (a functional ``keras.Model`` plus a
    bound ``generate`` callable). For inference it hands that same artifact back
    unchanged.
    """

    # Width passed to PositionalEmbedding and to every TransformerDecoder.
    hidden_dim: int
    # Second positional argument of each TransformerDecoder
    # (presumably its feed-forward width -- confirm against gpt_components).
    intermediate_dim: int
    # Third positional argument of each TransformerDecoder
    # (presumably the attention head count -- confirm).
    num_heads: int
    # How many TransformerDecoder blocks are stacked.
    num_layers: int

    def build_training_artifact(
        self,
        vocab_size: int,
        sequence_length: int,
    ) -> ModelArtifact:
        """Assemble the trainable mini-GPT graph and wrap it in a ModelArtifact.

        Args:
            vocab_size: Vocabulary size forwarded to ``PositionalEmbedding``.
            sequence_length: Forwarded to ``PositionalEmbedding`` (presumably the
                maximum number of positions it can embed -- confirm).

        Returns:
            A ``ModelArtifact`` whose ``model`` is the functional ``keras.Model``
            named ``"mini_gpt"``. Its ``generate`` is
            ``generate_with_training_model`` with that model bound as the first
            argument.
        """
        token_ids = keras.Input(shape=(None,), dtype="int32", name="inputs")

        # One embedding instance serves both ends of the network. It embeds the
        # token ids here and is called again with reverse=True on the final
        # hidden states (presumably projecting back onto the vocabulary), so the
        # input and output projections share that layer's weights.
        shared_embedding = PositionalEmbedding(
            sequence_length,
            vocab_size,
            self.hidden_dim,
            name="embedding",
        )
        hidden = shared_embedding(token_ids)

        input_norm = layers.LayerNormalization(name="input_layer_norm")
        hidden = input_norm(hidden)

        # Lazy generator: each decoder is created right before it is applied,
        # keeping the same create/call order as an explicit indexed loop.
        decoder_stack = (
            TransformerDecoder(
                self.hidden_dim,
                self.intermediate_dim,
                self.num_heads,
                name=f"decoder_{layer_index}",
            )
            for layer_index in range(self.num_layers)
        )
        for decoder_block in decoder_stack:
            hidden = decoder_block(hidden)

        decoded = shared_embedding(hidden, reverse=True)
        gpt = keras.Model(token_ids, decoded, name="mini_gpt")

        generate_fn = partial(generate_with_training_model, gpt)
        return ModelArtifact(model=gpt, generate=generate_fn)

    def build_inference_artifact(
        self,
        training_artifact: ModelArtifact,
    ) -> ModelArtifact:
        """Return ``training_artifact`` itself.

        This model builds no separate inference graph. The training artifact,
        including its ``generate`` callable, is reused as-is.

        Args:
            training_artifact: The artifact produced by
                ``build_training_artifact``.

        Returns:
            The very same ``ModelArtifact`` object that was passed in.
        """
        return training_artifact