| import torch |
| import torch.nn as nn |
| import torch.nn.functional as F |
| from transformers import PreTrainedModel |
| from transformers.modeling_outputs import CausalLMOutput |
| from .configuration_gpt2workshop import GPT2WorkshopConfig |
| from transformers.generation import GenerationMixin |
|
|
|
|
def build_rope_cache(seq_len, head_dim, device, theta=10000.0):
    """Precompute the cosine and sine tables used for rotary position embeddings.

    Args:
        seq_len: Number of positions to tabulate (rows of the returned tables).
        head_dim: Per-head feature size. Must be even: one frequency is
            generated per pair of features.
        device: Device on which the tables are created.
        theta: Base of the geometric frequency progression.

    Returns:
        A ``(cos, sin)`` tuple of float32 tensors, each shaped
        ``(seq_len, head_dim // 2)``.

    Raises:
        ValueError: If ``head_dim`` is odd.
    """
    # With an odd head_dim, arange(0, head_dim, 2) yields one more frequency
    # than there are complete feature pairs, so the tables could never be
    # applied. Fail here instead of deep inside the attention forward pass.
    if head_dim % 2 != 0:
        raise ValueError(f"head_dim must be even for rotary embeddings, got {head_dim}")

    pair_offsets = torch.arange(0, head_dim, 2, device=device).float()
    freqs = 1.0 / (theta ** (pair_offsets / head_dim))
    positions = torch.arange(seq_len, device=device).float()
    # angles[p, i] = position p times frequency i.
    angles = torch.outer(positions, freqs)
    return torch.cos(angles), torch.sin(angles)
|
|
|
|
def apply_rotary_embeddings(x, rope_cos, rope_sin):
    """Rotate interleaved (even, odd) feature pairs of ``x`` by per-position angles.

    Args:
        x: Tensor whose dimension 2 indexes sequence positions and whose last
            dimension holds the features to rotate. The attention module in
            this file passes ``(batch, heads, seq, head_dim)``.
        rope_cos: Cosine table indexed as ``[position, pair]``.
        rope_sin: Sine table with the same layout as ``rope_cos``.

    Returns:
        A tensor with the same shape as ``x`` in which each pair
        ``(x[..., 2i], x[..., 2i + 1])`` has been rotated by the angle stored
        for its position and pair index.

    Raises:
        ValueError: If either table holds fewer positions than ``x`` has.
    """
    seq_len = x.shape[2]
    available = min(rope_cos.shape[0], rope_sin.shape[0])
    # Slicing a too-short table does not fail by itself: a one-row table would
    # silently broadcast position 0 to every token, and other lengths surface
    # as an opaque broadcasting error. Reject both cases explicitly.
    if seq_len > available:
        raise ValueError(
            f"sequence length {seq_len} exceeds the rotary cache length {available}"
        )

    # Keep the first seq_len rows and add two leading singleton axes so the
    # tables broadcast over dimensions 0 and 1 of x.
    cos = rope_cos[:seq_len, :].unsqueeze(0).unsqueeze(0)
    sin = rope_sin[:seq_len, :].unsqueeze(0).unsqueeze(0)

    even, odd = x[..., 0::2], x[..., 1::2]
    rotated_even = even * cos - odd * sin
    rotated_odd = even * sin + odd * cos
    # Stacking on a new last axis and flattening the final two axes
    # re-interleaves the rotated halves into the original feature order.
    return torch.stack((rotated_even, rotated_odd), dim=-1).flatten(-2)
|
|
|
|
def relu_squared(x):
    """Return the element-wise square of ``relu(x)``.

    Negative entries map to zero; non-negative entries map to their square.
    """
    positive_part = F.relu(x)
    return torch.square(positive_part)
|
|
|
|
def soft_cap_logits(logits, cap=30.0):
    """Smoothly limit logits with a scaled tanh.

    Args:
        logits: Tensor of raw scores.
        cap: Scale of the tanh; the result is ``cap * tanh(logits / cap)``, so
            for a positive ``cap`` its magnitude never exceeds ``cap``.

    Returns:
        Tensor with the same shape as ``logits``.
    """
    scaled = logits / cap
    squashed = torch.tanh(scaled)
    return cap * squashed
|
|
|
|
class MultiHeadAttention(nn.Module):
    """Causal multi-head self-attention with rotary position embeddings.

    Queries, keys and values come from separate bias-free linear projections
    of the input. Queries and keys are rotated with the supplied RoPE tables
    before ``F.scaled_dot_product_attention`` is called with ``is_causal=True``.

    Reads ``num_heads``, ``head_dim``, ``hidden_dim`` and ``dropout`` from
    ``config``.

    Raises:
        ValueError: If ``num_heads * head_dim`` differs from ``hidden_dim``.
    """

    def __init__(self, config):
        super().__init__()
        # forward() reshapes a hidden_dim-wide projection into
        # (num_heads, head_dim); any other product can never be viewed that
        # way, so reject it at construction time with a readable message.
        if config.num_heads * config.head_dim != config.hidden_dim:
            raise ValueError(
                f"num_heads * head_dim ({config.num_heads} * {config.head_dim}) "
                f"must equal hidden_dim ({config.hidden_dim})"
            )
        self.num_heads = config.num_heads
        self.head_dim = config.head_dim
        self.query_projection = nn.Linear(config.hidden_dim, config.hidden_dim, bias=False)
        self.key_projection = nn.Linear(config.hidden_dim, config.hidden_dim, bias=False)
        self.value_projection = nn.Linear(config.hidden_dim, config.hidden_dim, bias=False)
        self.output_projection = nn.Linear(config.hidden_dim, config.hidden_dim, bias=False)
        self.attn_dropout_rate = config.dropout

    def _split_heads(self, projected):
        """Reshape ``(batch, seq, hidden)`` into ``(batch, heads, seq, head_dim)``."""
        batch_size, seq_len, _ = projected.shape
        return projected.view(batch_size, seq_len, self.num_heads, self.head_dim).transpose(1, 2)

    def forward(self, x, rope_cos, rope_sin):
        """Apply causal self-attention to ``x``.

        Args:
            x: Input of shape ``(batch, seq, hidden_dim)``.
            rope_cos: Cosine table passed through to ``apply_rotary_embeddings``.
            rope_sin: Sine table passed through to ``apply_rotary_embeddings``.

        Returns:
            Tensor of shape ``(batch, seq, hidden_dim)``.
        """
        batch_size, seq_len, _ = x.shape
        q = self._split_heads(self.query_projection(x))
        k = self._split_heads(self.key_projection(x))
        v = self._split_heads(self.value_projection(x))

        # Only queries and keys carry positional rotation; values are left as is.
        q = apply_rotary_embeddings(q, rope_cos, rope_sin)
        k = apply_rotary_embeddings(k, rope_cos, rope_sin)

        # Attention dropout is active in training mode only.
        dropout_p = self.attn_dropout_rate if self.training else 0.0
        attn_output = F.scaled_dot_product_attention(q, k, v, is_causal=True, dropout_p=dropout_p)

        # Merge the heads back: (batch, heads, seq, head_dim) -> (batch, seq, hidden).
        attn_output = attn_output.transpose(1, 2).contiguous().view(batch_size, seq_len, -1)
        return self.output_projection(attn_output)
|
|
|
|
class FeedForwardNetwork(nn.Module):
    """Two-layer bias-free MLP with a squared-ReLU activation and output dropout.

    The inner width is ``config.hidden_dim * config.ffn_expansion``; the
    dropout probability is ``config.dropout``.
    """

    def __init__(self, config):
        super().__init__()
        model_width = config.hidden_dim
        inner_width = model_width * config.ffn_expansion
        self.up_projection = nn.Linear(model_width, inner_width, bias=False)
        self.down_projection = nn.Linear(inner_width, model_width, bias=False)
        self.dropout = nn.Dropout(config.dropout)

    def forward(self, x):
        """Expand, activate, project back to the model width, then apply dropout."""
        expanded = self.up_projection(x)
        activated = relu_squared(expanded)
        projected = self.down_projection(activated)
        return self.dropout(projected)
|
|
|
|
class TransformerBlock(nn.Module):
    """One residual layer: normalized attention followed by a normalized MLP.

    Each sub-layer reads an RMS-normalized copy of the running activations,
    and its output passes through dropout before being added back to the
    un-normalized stream.
    """

    def __init__(self, config):
        super().__init__()
        width = config.hidden_dim
        drop_prob = config.dropout
        self.attention_norm = nn.RMSNorm(width, eps=1e-6)
        self.attention = MultiHeadAttention(config)
        self.ffn_norm = nn.RMSNorm(width, eps=1e-6)
        self.feed_forward = FeedForwardNetwork(config)
        self.attention_residual_dropout = nn.Dropout(drop_prob)
        self.ffn_residual_dropout = nn.Dropout(drop_prob)

    def forward(self, x, rope_cos, rope_sin):
        """Run the attention and feed-forward sub-layers with residual connections."""
        attention_branch = self.attention(self.attention_norm(x), rope_cos, rope_sin)
        x = x + self.attention_residual_dropout(attention_branch)

        ffn_branch = self.feed_forward(self.ffn_norm(x))
        return x + self.ffn_residual_dropout(ffn_branch)
|
|
|
|
class GPT2WorkshopForCausalLM(PreTrainedModel, GenerationMixin):
    """Decoder-only causal language model built from ``TransformerBlock`` layers.

    The model has no separate output-head parameter: logits are computed with
    ``F.linear`` against ``token_embedding.weight`` and then soft-capped with
    ``config.logit_soft_cap``. Rotary cos/sin tables for
    ``config.context_length`` positions are stored as buffers.
    """

    config_class = GPT2WorkshopConfig
    # No separate output-head parameter exists, so no tied keys are declared.
    # NOTE(review): presumably kept empty so the Hugging Face weight-tying
    # machinery has nothing to process; confirm against the installed
    # transformers version.
    _tied_weights_keys = {}

    def __init__(self, config):
        super().__init__(config)
        self.config = config
        self.token_embedding = nn.Embedding(config.vocab_size, config.hidden_dim)
        self.embedding_dropout = nn.Dropout(config.dropout)
        self.layers = nn.ModuleList([TransformerBlock(config) for _ in range(config.num_layers)])
        self.final_norm = nn.RMSNorm(config.hidden_dim, eps=1e-6)
        # Tables are built on CPU and registered as buffers, so they are moved
        # together with the module by .to() / .cuda().
        rope_cos, rope_sin = build_rope_cache(config.context_length, config.head_dim, device="cpu", theta=config.rope_theta)
        self.register_buffer("rope_cos", rope_cos)
        self.register_buffer("rope_sin", rope_sin)

    def forward(self, input_ids=None, attention_mask=None, labels=None, **kwargs):
        """Compute logits and, when ``labels`` is given, the cross-entropy loss.

        Args:
            input_ids: Token ids of shape ``(batch, seq)``. Required.
            attention_mask: Accepted for API compatibility but not used here;
                attention is causal only, so padding positions are not masked.
            labels: Optional targets with the same shape as ``input_ids``.
                Entries equal to ``-100`` are ignored by the loss.
            **kwargs: Accepted and ignored.

        Returns:
            ``CausalLMOutput`` with ``logits`` of shape ``(batch, seq, vocab)``
            and ``loss`` (``None`` when ``labels`` is not given).

        Raises:
            ValueError: If ``input_ids`` is missing, or the sequence is longer
                than the rotary position cache.
        """
        if input_ids is None:
            raise ValueError("input_ids must be provided")

        # The rotary tables cover a fixed number of positions; a longer input
        # would otherwise fail inside the attention layers with an opaque
        # shape error.
        max_positions = self.rope_cos.shape[0]
        seq_len = input_ids.shape[-1]
        if seq_len > max_positions:
            raise ValueError(
                f"input sequence length {seq_len} exceeds the model's "
                f"context length {max_positions}"
            )

        hidden_states = self.embedding_dropout(self.token_embedding(input_ids))
        for layer in self.layers:
            hidden_states = layer(hidden_states, self.rope_cos, self.rope_sin)
        hidden_states = self.final_norm(hidden_states)

        # The output projection reuses the embedding matrix.
        logits = soft_cap_logits(F.linear(hidden_states, self.token_embedding.weight), self.config.logit_soft_cap)

        loss = None
        if labels is not None:
            # NOTE(review): logits and labels are compared position by
            # position; no shift is applied in this method. Presumably callers
            # pass labels already aligned to next-token targets; confirm
            # against the training pipeline.
            loss = F.cross_entropy(logits.reshape(-1, logits.size(-1)), labels.reshape(-1), ignore_index=-100)
        return CausalLMOutput(loss=loss, logits=logits)

    def prepare_inputs_for_generation(self, input_ids, **kwargs):
        """Return the model inputs for one generation step.

        Only ``input_ids`` is forwarded; every other keyword is dropped, so
        each step re-processes the full sequence it is given.
        """
        return {"input_ids": input_ids}

    @property
    def all_tied_weights_keys(self):
        """Return an empty mapping: this model declares no tied weight keys."""
        return {}