general-deep-learning / models /mini_gpt /model_builder.py
yetrun's picture
ver1: 实现深度学习训练框架,支持 Wiki GPT 与诗歌生成双任务
a5fd608
from dataclasses import dataclass
from functools import partial
import keras
from keras import layers
from models.mini_gpt.gpt_components import PositionalEmbedding, TransformerDecoder
from pipeline.base.generation import generate_with_training_model
from pipeline.base.model_builder import ModelArtifact
@dataclass
class GptModelBuilder:
    """Assemble the ``mini_gpt`` Keras model and wrap it in a ``ModelArtifact``.

    Attributes:
        hidden_dim: Forwarded to ``PositionalEmbedding`` and to every
            ``TransformerDecoder``.
        intermediate_dim: Forwarded to every ``TransformerDecoder``.
        num_heads: Forwarded to every ``TransformerDecoder``.
        num_layers: How many ``TransformerDecoder`` layers are chained.
    """

    hidden_dim: int
    intermediate_dim: int
    num_heads: int
    num_layers: int

    def _apply_decoder_stack(self, hidden):
        """Create ``num_layers`` decoders and thread ``hidden`` through them."""
        # Each decoder is created and called in the same iteration so layers
        # are instantiated and built strictly in stack order.
        for i in range(self.num_layers):
            block = TransformerDecoder(
                self.hidden_dim,
                self.intermediate_dim,
                self.num_heads,
                name=f"decoder_{i}",
            )
            hidden = block(hidden)
        return hidden

    def build_training_artifact(
        self,
        vocab_size: int,
        sequence_length: int
    ) -> ModelArtifact:
        """Build the functional model used for training.

        Args:
            vocab_size: Forwarded to ``PositionalEmbedding`` (second argument).
            sequence_length: Forwarded to ``PositionalEmbedding`` (first
                argument).

        Returns:
            A ``ModelArtifact`` holding the ``mini_gpt`` model and a
            ``generate`` callable: ``generate_with_training_model`` with the
            model pre-bound as its first argument.
        """
        # int32 input with an unspecified length along its single declared axis.
        token_input = keras.Input(shape=(None,), dtype="int32", name="inputs")

        shared_embedding = PositionalEmbedding(
            sequence_length,
            vocab_size,
            self.hidden_dim,
            name="embedding",
        )
        embedded = shared_embedding(token_input)

        input_norm = layers.LayerNormalization(name="input_layer_norm")
        hidden = self._apply_decoder_stack(input_norm(embedded))

        # The same embedding layer instance is reused for the output, called
        # with reverse=True.
        # NOTE(review): presumably this projects back to vocabulary logits
        # with tied weights - confirm against PositionalEmbedding.
        model_outputs = shared_embedding(hidden, reverse=True)

        gpt = keras.Model(
            inputs=token_input,
            outputs=model_outputs,
            name="mini_gpt",
        )
        generate_fn = partial(generate_with_training_model, gpt)
        return ModelArtifact(model=gpt, generate=generate_fn)

    def build_inference_artifact(
        self,
        training_artifact: ModelArtifact
    ) -> ModelArtifact:
        """Return the artifact to use for inference.

        The training artifact is handed back unchanged; no separate
        inference model is built.

        Args:
            training_artifact: Artifact produced by
                ``build_training_artifact``.

        Returns:
            The very same ``training_artifact`` object.
        """
        inference_artifact = training_artifact
        return inference_artifact