"""
Shared building blocks for the GPT model:
- Positional Encoding
- Transformer Decoder
"""

import keras
from keras import layers, ops


class PositionalEmbedding(keras.Layer):
    """Sum of a learned token embedding and a learned position embedding.

    The layer can also be run "in reverse" to project hidden states back
    onto the vocabulary by reusing the token embedding matrix.

    Args:
        sequence_length: Number of rows in the position embedding table.
        input_dim: Number of rows in the token embedding table.
        output_dim: Width of both embedding tables.
        **kwargs: Forwarded unchanged to ``keras.Layer``.
    """

    def __init__(self, sequence_length, input_dim, output_dim, **kwargs):
        super().__init__(**kwargs)
        self.token_embeddings = layers.Embedding(
            input_dim=input_dim, output_dim=output_dim
        )
        self.position_embeddings = layers.Embedding(
            input_dim=sequence_length, output_dim=output_dim
        )

    def call(self, inputs, reverse=False):
        """Embed ``inputs``, or project them onto the token table.

        Args:
            inputs: With ``reverse=False``, the indices looked up in the
                token embedding table. With ``reverse=True``, a tensor
                that is matrix-multiplied by the transposed token
                embedding matrix (presumably hidden states whose last
                axis has size ``output_dim`` -- confirm against callers).
            reverse: Selects the projection path instead of the lookup.

        Returns:
            The token embeddings plus position embeddings, or, in reverse
            mode, the product of ``inputs`` with the transposed token
            embedding matrix.
        """
        if not reverse:
            # A running sum of ones along the last axis, shifted down by
            # one, yields the position indices 0, 1, 2, ... for each row.
            ones = ops.ones_like(inputs)
            position_ids = ops.cumsum(ones, axis=-1) - 1
            return self.token_embeddings(inputs) + self.position_embeddings(
                position_ids
            )

        # Reverse mode: reuse the token embedding weights as the output
        # projection.
        embedding_matrix = self.token_embeddings.embeddings
        return ops.matmul(inputs, ops.transpose(embedding_matrix))


class TransformerDecoder(keras.Layer):
    """Decoder block: causal self-attention followed by a feed-forward net.

    Each of the two sub-blocks applies dropout, adds its input back as a
    residual, and then applies layer normalization.

    Args:
        hidden_dim: Output width of the second feed-forward layer; also
            divided by ``num_heads`` to obtain the per-head key size.
        intermediate_dim: Width of the first (ReLU) feed-forward layer.
        num_heads: Number of attention heads.
        **kwargs: Forwarded unchanged to ``keras.Layer``.
    """

    def __init__(self, hidden_dim, intermediate_dim, num_heads, **kwargs):
        super().__init__(**kwargs)
        self.hidden_dim = hidden_dim
        self.intermediate_dim = intermediate_dim

        # Self-attention sub-block. Integer division: any remainder of
        # hidden_dim / num_heads is dropped.
        per_head_dim = hidden_dim // num_heads
        self.self_attention = layers.MultiHeadAttention(
            num_heads=num_heads, key_dim=per_head_dim, dropout=0.1
        )
        self.self_attention_layernorm = layers.LayerNormalization()

        # Feed-forward sub-block.
        self.feed_forward_1 = layers.Dense(
            units=intermediate_dim, activation="relu"
        )
        self.feed_forward_2 = layers.Dense(units=hidden_dim)
        self.feed_forward_layernorm = layers.LayerNormalization()

        # One dropout layer, applied after both sub-blocks in ``call``.
        self.dropout = layers.Dropout(rate=0.1)

    def call(self, inputs):
        """Run ``inputs`` through attention and feed-forward sub-blocks.

        Args:
            inputs: Tensor used as query, key and value of the
                self-attention layer.

        Returns:
            The output of the final layer normalization.
        """
        # Self-attention with a causal mask, then dropout, residual
        # connection and layer normalization.
        attended = self.self_attention(
            query=inputs, key=inputs, value=inputs, use_causal_mask=True
        )
        attended = self.dropout(attended)
        attended = self.self_attention_layernorm(attended + inputs)

        # Two dense layers, then dropout, residual connection and layer
        # normalization.
        projected = self.feed_forward_1(attended)
        projected = self.feed_forward_2(projected)
        projected = self.dropout(projected)
        return self.feed_forward_layernorm(projected + attended)