Spaces:
Sleeping
Sleeping
| from dataclasses import dataclass | |
| from functools import partial | |
| import keras | |
| from keras import layers | |
| from models.mini_gpt.gpt_components import PositionalEmbedding, TransformerDecoder | |
| from pipeline.base.generation import generate_with_training_model | |
| from pipeline.base.model_builder import ModelArtifact | |
@dataclass
class GptModelBuilder:
    """Builds mini-GPT Keras models wrapped as ModelArtifacts.

    Fields configure the width/depth of the transformer decoder stack.
    NOTE: `@dataclass` is required here — `dataclass` is imported above and
    the methods read these fields via `self.*`, so without the decorator the
    class has no generated `__init__` and cannot be configured.
    """

    hidden_dim: int        # model/embedding width
    intermediate_dim: int  # feed-forward inner width of each decoder block
    num_heads: int         # attention heads per decoder block
    num_layers: int        # number of stacked TransformerDecoder blocks

    def build_training_artifact(
        self,
        vocab_size: int,
        sequence_length: int
    ) -> ModelArtifact:
        """Build the training model and wrap it in a ModelArtifact.

        Pipeline: token ids -> positional embedding -> layer norm ->
        `num_layers` decoder blocks -> output projection via the embedding
        (reverse=True — presumably tied-weight projection back to vocab
        logits; confirm against PositionalEmbedding).

        Args:
            vocab_size: size of the token vocabulary.
            sequence_length: maximum sequence length for positional embeddings.

        Returns:
            ModelArtifact holding the compiled-graph model and a `generate`
            callable bound to it.
        """
        inputs = keras.Input(shape=(None,), dtype="int32", name="inputs")
        embedding = PositionalEmbedding(
            sequence_length,
            vocab_size,
            self.hidden_dim,
            name="embedding"
        )
        x = embedding(inputs)
        x = layers.LayerNormalization(name="input_layer_norm")(x)
        for i in range(self.num_layers):
            decoder = TransformerDecoder(
                self.hidden_dim,
                self.intermediate_dim,
                self.num_heads,
                name=f"decoder_{i}"
            )
            x = decoder(x)
        # Reuse the input embedding for the output projection.
        outputs = embedding(x, reverse=True)
        model = keras.Model(inputs, outputs, name="mini_gpt")
        return ModelArtifact(
            model=model,
            generate=partial(generate_with_training_model, model)
        )

    def build_inference_artifact(
        self,
        training_artifact: ModelArtifact
    ) -> ModelArtifact:
        """Return the training artifact unchanged — the training model is
        used directly for inference; no export/conversion step is needed."""
        return training_artifact