import torch
import torch.nn as nn
import torch.nn.functional as F
from transformers import PreTrainedModel
from transformers.generation import GenerationMixin
from transformers.modeling_outputs import CausalLMOutput

from .configuration_gpt2workshop import GPT2WorkshopConfig


def build_rope_cache(seq_len, head_dim, device, theta=10000.0):
    """Precompute the cosine/sine tables used for rotary position embeddings.

    Args:
        seq_len: Number of positions (rows) to tabulate.
        head_dim: Per-head feature size; one frequency is generated for every
            index in ``range(0, head_dim, 2)``.
        device: Device on which the tables are created.
        theta: Base of the geometric frequency progression.

    Returns:
        A ``(cos, sin)`` tuple of float tensors, each with ``seq_len`` rows and
        one column per generated frequency.
    """
    freqs = 1.0 / (theta ** (torch.arange(0, head_dim, 2, device=device).float() / head_dim))
    positions = torch.arange(seq_len, device=device).float()
    angles = torch.outer(positions, freqs)
    return torch.cos(angles), torch.sin(angles)


def apply_rotary_embeddings(x, rope_cos, rope_sin):
    """Rotate consecutive feature pairs of ``x`` by position-dependent angles.

    Args:
        x: Tensor whose dimension 2 indexes the sequence position and whose
            last dimension holds the features to rotate.
        rope_cos: Cosine table as returned by ``build_rope_cache``.
        rope_sin: Sine table as returned by ``build_rope_cache``.

    Returns:
        A tensor with the same shape and dtype as ``x``.

    Raises:
        ValueError: If the sequence length of ``x`` exceeds the number of
            cached positions.
    """
    seq_len = x.shape[2]
    cached_len = rope_cos.shape[0]
    if seq_len > cached_len:
        # Without this guard the slice below is silently truncated and the
        # failure surfaces later as an opaque broadcasting error.
        raise ValueError(
            f"Sequence length {seq_len} exceeds the rotary cache length {cached_len}"
        )

    # Add leading singleton dims so the tables broadcast over the first two
    # dimensions of x.
    cos = rope_cos[:seq_len, :].unsqueeze(0).unsqueeze(0)
    sin = rope_sin[:seq_len, :].unsqueeze(0).unsqueeze(0)

    even, odd = x[..., 0::2], x[..., 1::2]
    rotated = torch.stack(
        (even * cos - odd * sin, even * sin + odd * cos), dim=-1
    ).flatten(-2)

    # The tables may have a wider dtype than x (e.g. float32 tables with
    # half-precision activations); cast back so callers get x's dtype.
    return rotated.to(x.dtype)


def relu_squared(x):
    """Return ``relu(x) ** 2`` element-wise."""
    return F.relu(x).square()


def soft_cap_logits(logits, cap=30.0):
    """Smoothly bound ``logits`` to the open interval ``(-cap, cap)`` via tanh.

    Args:
        logits: Tensor of raw logits.
        cap: Bound applied to the logits. ``None`` disables capping and
            returns ``logits`` unchanged.

    Returns:
        The capped logits, same shape as the input.
    """
    if cap is None:
        return logits
    return cap * torch.tanh(logits / cap)


class MultiHeadAttention(nn.Module):
    """Causal multi-head self-attention with rotary position embeddings."""

    def __init__(self, config):
        """Create the bias-free q/k/v/output projections described by ``config``.

        Reads ``num_heads``, ``head_dim``, ``hidden_dim`` and ``dropout``
        from ``config``.
        """
        super().__init__()
        self.num_heads = config.num_heads
        self.head_dim = config.head_dim
        self.query_projection = nn.Linear(config.hidden_dim, config.hidden_dim, bias=False)
        self.key_projection = nn.Linear(config.hidden_dim, config.hidden_dim, bias=False)
        self.value_projection = nn.Linear(config.hidden_dim, config.hidden_dim, bias=False)
        self.output_projection = nn.Linear(config.hidden_dim, config.hidden_dim, bias=False)
        self.attn_dropout_rate = config.dropout

    def _split_heads(self, projected, batch_size, seq_len):
        """Reshape (batch, seq, heads * head_dim) to (batch, heads, seq, head_dim)."""
        return projected.view(
            batch_size, seq_len, self.num_heads, self.head_dim
        ).transpose(1, 2)

    def forward(self, x, rope_cos, rope_sin):
        """Apply causal self-attention to ``x``.

        Args:
            x: Input of shape (batch, seq, features).
            rope_cos: Cosine table for the rotary embedding.
            rope_sin: Sine table for the rotary embedding.

        Returns:
            The output projection of the concatenated attention heads, with
            shape (batch, seq, hidden_dim).
        """
        batch_size, seq_len, _ = x.shape

        q = self._split_heads(self.query_projection(x), batch_size, seq_len)
        k = self._split_heads(self.key_projection(x), batch_size, seq_len)
        v = self._split_heads(self.value_projection(x), batch_size, seq_len)

        # Only queries and keys carry positional information.
        q = apply_rotary_embeddings(q, rope_cos, rope_sin)
        k = apply_rotary_embeddings(k, rope_cos, rope_sin)

        # Attention dropout is active in training mode only.
        dropout_p = self.attn_dropout_rate if self.training else 0.0
        attn_output = F.scaled_dot_product_attention(
            q, k, v, is_causal=True, dropout_p=dropout_p
        )

        # Merge the heads back into a single feature dimension.
        attn_output = attn_output.transpose(1, 2).contiguous().view(batch_size, seq_len, -1)
        return self.output_projection(attn_output)


class FeedForwardNetwork(nn.Module):
    """Two-layer bias-free MLP with a squared-ReLU activation and dropout."""

    def __init__(self, config):
        """Build the up/down projections; the inner width is
        ``config.hidden_dim * config.ffn_expansion``."""
        super().__init__()
        ffn_dim = config.hidden_dim * config.ffn_expansion
        self.up_projection = nn.Linear(config.hidden_dim, ffn_dim, bias=False)
        self.down_projection = nn.Linear(ffn_dim, config.hidden_dim, bias=False)
        self.dropout = nn.Dropout(config.dropout)

    def forward(self, x):
        """Return ``dropout(down(relu_squared(up(x))))``."""
        hidden = relu_squared(self.up_projection(x))
        return self.dropout(self.down_projection(hidden))


class TransformerBlock(nn.Module):
    """Pre-norm transformer block: attention and feed-forward sub-layers, each
    wrapped in RMSNorm -> sub-layer -> dropout -> residual add."""

    def __init__(self, config):
        """Instantiate both sub-layers, their RMSNorms and residual dropouts."""
        super().__init__()
        self.attention_norm = nn.RMSNorm(config.hidden_dim, eps=1e-6)
        self.attention = MultiHeadAttention(config)
        self.ffn_norm = nn.RMSNorm(config.hidden_dim, eps=1e-6)
        self.feed_forward = FeedForwardNetwork(config)
        self.attention_residual_dropout = nn.Dropout(config.dropout)
        self.ffn_residual_dropout = nn.Dropout(config.dropout)

    def forward(self, x, rope_cos, rope_sin):
        """Run the attention and feed-forward sub-layers with residual connections.

        Args:
            x: Hidden states passed to the attention sub-layer.
            rope_cos: Cosine table for the rotary embedding.
            rope_sin: Sine table for the rotary embedding.

        Returns:
            The updated hidden states, same shape as ``x``.
        """
        attn_out = self.attention(self.attention_norm(x), rope_cos, rope_sin)
        x = x + self.attention_residual_dropout(attn_out)

        ffn_out = self.feed_forward(self.ffn_norm(x))
        x = x + self.ffn_residual_dropout(ffn_out)
        return x


class GPT2WorkshopForCausalLM(PreTrainedModel, GenerationMixin):
    """Decoder-only language model whose output head reuses the token
    embedding matrix and whose logits are soft-capped."""

    config_class = GPT2WorkshopConfig
    # There is no separate output-head parameter: forward() multiplies by
    # token_embedding.weight directly, so no tied keys are declared.
    _tied_weights_keys = {}

    def __init__(self, config):
        """Build the embedding, the transformer stack, the final norm and the
        rotary position tables from ``config``."""
        super().__init__(config)
        self.config = config
        self.token_embedding = nn.Embedding(config.vocab_size, config.hidden_dim)
        self.embedding_dropout = nn.Dropout(config.dropout)
        self.layers = nn.ModuleList(
            [TransformerBlock(config) for _ in range(config.num_layers)]
        )
        self.final_norm = nn.RMSNorm(config.hidden_dim, eps=1e-6)

        # Tables are built on CPU for config.context_length positions and
        # registered as buffers so they follow the module across devices.
        rope_cos, rope_sin = build_rope_cache(
            config.context_length, config.head_dim, device="cpu", theta=config.rope_theta
        )
        self.register_buffer("rope_cos", rope_cos)
        self.register_buffer("rope_sin", rope_sin)

    def forward(self, input_ids=None, attention_mask=None, labels=None, **kwargs):
        """Compute logits and, when ``labels`` is given, the cross-entropy loss.

        Args:
            input_ids: Token ids to embed.
            attention_mask: Accepted for interface compatibility but not used;
                attention is causal only, so padding positions are not masked.
            labels: Optional targets, flattened and compared against the
                flattened logits; entries equal to -100 are ignored.
            **kwargs: Ignored.

        Returns:
            ``CausalLMOutput`` with ``loss`` (``None`` without labels) and
            ``logits``.
        """
        x = self.embedding_dropout(self.token_embedding(input_ids))
        for layer in self.layers:
            x = layer(x, self.rope_cos, self.rope_sin)
        x = self.final_norm(x)

        # The embedding matrix doubles as the output projection.
        logits = soft_cap_logits(
            F.linear(x, self.token_embedding.weight), self.config.logit_soft_cap
        )

        loss = None
        if labels is not None:
            # NOTE(review): logits and labels are compared position-for-position
            # with no internal shift, so callers must supply labels already
            # aligned to the next token - confirm against the training pipeline.
            loss = F.cross_entropy(
                logits.reshape(-1, logits.size(-1)),
                labels.reshape(-1),
                ignore_index=-100,
            )
        return CausalLMOutput(loss=loss, logits=logits)

    def prepare_inputs_for_generation(self, input_ids, **kwargs):
        """Return the model inputs for one generation step.

        Only ``input_ids`` is forwarded; every other keyword argument is
        dropped, so the full sequence is re-encoded on each step.
        """
        return {"input_ids": input_ids}

    @property
    def all_tied_weights_keys(self):
        """Return an empty mapping: the model declares no tied parameters.

        NOTE(review): presumably present for compatibility with transformers
        versions that read this attribute - confirm.
        """
        return {}