Spaces:
Running
Running
| """ | |
| GPT模型的共享组件模块: | |
| - Positional Encoding | |
| - Transformer Decoder | |
| """ | |
| import keras | |
| from keras import layers, ops | |
class PositionalEmbedding(keras.Layer):
    """Learned token embedding plus learned absolute position embedding.

    Forward mode (``reverse=False``): ``inputs`` are token ids. Each id is
    looked up in a token table, and its index along the last axis
    (0, 1, 2, ...) is looked up in a position table. The two results are
    summed.

    Reverse mode (``reverse=True``): ``inputs`` are multiplied by the
    transpose of the token table. This reuses the input embedding as an
    output projection, presumably to produce vocabulary logits.

    Args:
        sequence_length: Number of rows in the position table. Positions
            at or beyond this index have no row. Callers are assumed never
            to exceed it (TODO confirm).
        input_dim: Vocabulary size (rows of the token table).
        output_dim: Embedding width shared by both tables.
        **kwargs: Forwarded to ``keras.Layer``.
    """

    def __init__(self, sequence_length, input_dim, output_dim, **kwargs):
        super().__init__(**kwargs)
        # Creation order (token table first) is kept as-is so auto-generated
        # sublayer names and weight ordering stay stable.
        self.token_embeddings = layers.Embedding(input_dim, output_dim)
        self.position_embeddings = layers.Embedding(sequence_length, output_dim)

    def call(self, inputs, reverse=False):
        if not reverse:
            # Running count of ones along the last axis, minus one, gives
            # position indices 0..L-1 with the same shape as `inputs`.
            position_ids = ops.cumsum(ops.ones_like(inputs), axis=-1) - 1
            # Left operand is evaluated first: token lookup, then position.
            return self.token_embeddings(inputs) + self.position_embeddings(
                position_ids
            )
        # Reverse mode: project back through the shared token table. The
        # last axis of `inputs` must match the embedding width for the
        # matmul to be valid.
        vocab_table = self.token_embeddings.embeddings
        return ops.matmul(inputs, ops.transpose(vocab_table))
class TransformerDecoder(keras.Layer):
    """Decoder-only Transformer block: causal self-attention + feed-forward.

    Each of the two sub-blocks is followed by dropout, a residual add of the
    sub-block's input, and LayerNormalization (post-norm ordering). The
    attention uses a causal mask, so each position attends only to itself
    and earlier positions.

    Because the feed-forward output (width ``hidden_dim``) is added back to
    the sub-block input, the last axis of ``inputs`` must equal
    ``hidden_dim``.

    Args:
        hidden_dim: Feature size of the inputs and outputs.
        intermediate_dim: Width of the first (ReLU) feed-forward layer.
        num_heads: Number of attention heads. Each head gets
            ``hidden_dim // num_heads`` features (integer division, so any
            remainder is silently dropped from the per-head size).
        dropout_rate: Dropout rate for attention weights and for the output
            of each sub-block. Defaults to 0.1, the previously hard-coded
            value.
        **kwargs: Forwarded to ``keras.Layer`` (e.g. ``name``).
    """

    def __init__(
        self, hidden_dim, intermediate_dim, num_heads, dropout_rate=0.1, **kwargs
    ):
        super().__init__(**kwargs)
        # Keep every constructor argument so `get_config` can rebuild the layer.
        self.hidden_dim = hidden_dim
        self.intermediate_dim = intermediate_dim
        self.num_heads = num_heads
        self.dropout_rate = dropout_rate
        key_dim = hidden_dim // num_heads
        # Self-attention sub-block.
        self.self_attention = layers.MultiHeadAttention(
            num_heads, key_dim, dropout=dropout_rate
        )
        self.self_attention_layernorm = layers.LayerNormalization()
        # Feed-forward sub-block.
        self.feed_forward_1 = layers.Dense(intermediate_dim, activation="relu")
        self.feed_forward_2 = layers.Dense(hidden_dim)
        self.feed_forward_layernorm = layers.LayerNormalization()
        # One Dropout instance is reused after both sub-blocks; it holds no
        # weights, so sharing it is safe.
        self.dropout = layers.Dropout(dropout_rate)

    def call(self, inputs):
        # Self-attention sub-block (causal: no attending to later positions).
        residual = x = inputs
        x = self.self_attention(query=x, key=x, value=x, use_causal_mask=True)
        x = self.dropout(x)
        x = x + residual
        x = self.self_attention_layernorm(x)
        # Feed-forward sub-block.
        residual = x
        x = self.feed_forward_1(x)
        x = self.feed_forward_2(x)
        x = self.dropout(x)
        x = x + residual
        x = self.feed_forward_layernorm(x)
        return x

    def get_config(self):
        """Return the constructor arguments needed to re-create this layer."""
        config = super().get_config()
        config.update(
            {
                "hidden_dim": self.hidden_dim,
                "intermediate_dim": self.intermediate_dim,
                "num_heads": self.num_heads,
                "dropout_rate": self.dropout_rate,
            }
        )
        return config