File size: 5,208 Bytes
63b7820 a3b70ca 63b7820 a3b70ca 63b7820 8b251ce 63b7820 8b251ce 63b7820 8b251ce | 1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 | import torch
import torch.nn as nn
import torch.nn.functional as F
from transformers import PreTrainedModel
from transformers.modeling_outputs import CausalLMOutput
from .configuration_gpt2workshop import GPT2WorkshopConfig
from transformers.generation import GenerationMixin
def build_rope_cache(seq_len, head_dim, device, theta=10000.0):
    """Precompute the rotary-embedding cosine and sine tables.

    Args:
        seq_len: Number of positions to precompute (rows of each table).
        head_dim: Per-head feature size; one frequency is produced for each
            index in ``range(0, head_dim, 2)``.
        device: Device on which the tables are allocated.
        theta: Base of the geometric frequency progression.

    Returns:
        A ``(cos, sin)`` pair of float32 tensors of shape
        ``(seq_len, len(range(0, head_dim, 2)))`` holding the cosine and sine
        of ``position * inverse_frequency``.
    """
    even_dims = torch.arange(0, head_dim, 2, device=device).float()
    inv_freq = 1.0 / (theta ** (even_dims / head_dim))
    pos = torch.arange(seq_len, device=device).float()
    # Broadcasted outer product: row p holds p * inv_freq.
    angle_table = pos[:, None] * inv_freq[None, :]
    return angle_table.cos(), angle_table.sin()
def apply_rotary_embeddings(x, rope_cos, rope_sin):
    """Rotate interleaved feature pairs of ``x`` by position-dependent angles (RoPE).

    The feature pair ``(x[..., 2i], x[..., 2i + 1])`` at sequence position
    ``t`` is rotated by the angle whose cosine/sine are ``rope_cos[t, i]`` /
    ``rope_sin[t, i]``.

    Args:
        x: Tensor whose dim 2 is the sequence axis and whose last dim is the
            feature axis; as called from ``MultiHeadAttention`` it is
            ``(batch, num_heads, seq_len, head_dim)``.
        rope_cos: Cosine table indexed ``[position, pair]``; only the first
            ``x.shape[2]`` rows are used.
        rope_sin: Sine table with the same layout as ``rope_cos``.

    Returns:
        Tensor with the same shape and dtype as ``x``.
    """
    seq_len = x.shape[2]
    # Add two leading singleton dims so the tables broadcast over batch/heads.
    cos = rope_cos[:seq_len, :].unsqueeze(0).unsqueeze(0)
    sin = rope_sin[:seq_len, :].unsqueeze(0).unsqueeze(0)
    even, odd = x[..., 0::2], x[..., 1::2]
    # Stack the rotated pair on a new last axis, then flatten so the pairs
    # are re-interleaved in their original positions.
    rotated = torch.stack((even * cos - odd * sin, even * sin + odd * cos), dim=-1).flatten(-2)
    # The tables may have a higher precision than x (build_rope_cache creates
    # them as float32), in which case the products above are promoted. Cast
    # back so rotated q/k keep the same dtype as v; scaled_dot_product_attention
    # rejects query/key/value with mismatched dtypes.
    return rotated.to(x.dtype)
def relu_squared(x):
    """Squared-ReLU activation: ``max(x, 0) ** 2`` applied elementwise."""
    positive_part = F.relu(x)
    return torch.square(positive_part)
def soft_cap_logits(logits, cap=30.0):
    """Smoothly bound logit magnitudes by ``|cap|``.

    Computes ``cap * tanh(logits / cap)``, which is close to the identity for
    ``|logits|`` much smaller than ``cap`` and saturates toward ``±cap`` for
    large magnitudes.

    Args:
        logits: Tensor of raw logits.
        cap: Soft-cap magnitude. ``None`` disables capping and returns
            ``logits`` unchanged.

    Returns:
        Tensor of the same shape as ``logits``.
    """
    if cap is None:
        # Let configs switch soft-capping off instead of failing on
        # ``logits / None``.
        return logits
    return cap * torch.tanh(logits / cap)
class MultiHeadAttention(nn.Module):
    """Causal multi-head self-attention with rotary position embeddings.

    Reads ``num_heads``, ``head_dim``, ``hidden_dim`` and ``dropout`` from
    ``config``. All four projections are bias-free ``hidden_dim -> hidden_dim``
    linear layers, so ``num_heads * head_dim`` must equal ``hidden_dim``.
    """

    def __init__(self, config):
        super().__init__()
        if config.num_heads * config.head_dim != config.hidden_dim:
            # forward() reshapes each hidden_dim-wide projection into
            # (num_heads, head_dim); fail here with a clear message instead of
            # an opaque view() error on the first forward pass.
            raise ValueError(
                f"num_heads * head_dim ({config.num_heads} * {config.head_dim}) "
                f"must equal hidden_dim ({config.hidden_dim})"
            )
        self.num_heads = config.num_heads
        self.head_dim = config.head_dim
        self.query_projection = nn.Linear(config.hidden_dim, config.hidden_dim, bias=False)
        self.key_projection = nn.Linear(config.hidden_dim, config.hidden_dim, bias=False)
        self.value_projection = nn.Linear(config.hidden_dim, config.hidden_dim, bias=False)
        self.output_projection = nn.Linear(config.hidden_dim, config.hidden_dim, bias=False)
        # Kept as a plain rate (not an nn.Dropout module) because it is passed
        # to scaled_dot_product_attention as dropout_p.
        self.attn_dropout_rate = config.dropout

    def forward(self, x, rope_cos, rope_sin):
        """Apply causal self-attention to ``x``.

        Args:
            x: ``(batch, seq_len, hidden_dim)`` input.
            rope_cos: Rotary cosine table indexed by position.
            rope_sin: Rotary sine table indexed by position.

        Returns:
            ``(batch, seq_len, hidden_dim)`` output of the output projection.
        """
        batch_size, seq_len, _ = x.shape
        # (batch, seq, hidden) -> (batch, heads, seq, head_dim)
        q = self.query_projection(x).view(batch_size, seq_len, self.num_heads, self.head_dim).transpose(1, 2)
        k = self.key_projection(x).view(batch_size, seq_len, self.num_heads, self.head_dim).transpose(1, 2)
        v = self.value_projection(x).view(batch_size, seq_len, self.num_heads, self.head_dim).transpose(1, 2)
        # Rotary embeddings encode position in q and k only; v is left as-is.
        q = apply_rotary_embeddings(q, rope_cos, rope_sin)
        k = apply_rotary_embeddings(k, rope_cos, rope_sin)
        # Attention-weight dropout is active only in training mode.
        dropout_p = self.attn_dropout_rate if self.training else 0.0
        attn_output = F.scaled_dot_product_attention(q, k, v, is_causal=True, dropout_p=dropout_p)
        # Merge heads back: (batch, heads, seq, head_dim) -> (batch, seq, hidden).
        attn_output = attn_output.transpose(1, 2).contiguous().view(batch_size, seq_len, -1)
        return self.output_projection(attn_output)
class FeedForwardNetwork(nn.Module):
    """Position-wise MLP: expand, squared-ReLU, project back, then dropout.

    The inner width is ``config.hidden_dim * config.ffn_expansion``; both
    linear layers are bias-free.
    """

    def __init__(self, config):
        super().__init__()
        width = config.hidden_dim
        inner_width = width * config.ffn_expansion
        self.up_projection = nn.Linear(width, inner_width, bias=False)
        self.down_projection = nn.Linear(inner_width, width, bias=False)
        self.dropout = nn.Dropout(config.dropout)

    def forward(self, x):
        """Apply the MLP over the last dimension of ``x`` (size ``hidden_dim``)."""
        expanded = self.up_projection(x)
        activated = relu_squared(expanded)
        projected = self.down_projection(activated)
        return self.dropout(projected)
class TransformerBlock(nn.Module):
    """Pre-norm transformer layer with an attention branch and an MLP branch.

    Each branch normalizes its input with RMSNorm, and its output passes
    through a dedicated dropout before being added to the residual stream.
    """

    def __init__(self, config):
        super().__init__()
        dim = config.hidden_dim
        self.attention_norm = nn.RMSNorm(dim, eps=1e-6)
        self.attention = MultiHeadAttention(config)
        self.ffn_norm = nn.RMSNorm(dim, eps=1e-6)
        self.feed_forward = FeedForwardNetwork(config)
        self.attention_residual_dropout = nn.Dropout(config.dropout)
        # NOTE(review): FeedForwardNetwork already applies its own dropout to
        # its output, so with dropout > 0 the MLP branch is dropped out twice
        # — confirm this is intended.
        self.ffn_residual_dropout = nn.Dropout(config.dropout)

    def forward(self, x, rope_cos, rope_sin):
        """Run both residual branches; the rotary tables are forwarded to attention."""
        attn_branch = self.attention(self.attention_norm(x), rope_cos, rope_sin)
        hidden = x + self.attention_residual_dropout(attn_branch)
        ffn_branch = self.feed_forward(self.ffn_norm(hidden))
        return hidden + self.ffn_residual_dropout(ffn_branch)
class GPT2WorkshopForCausalLM(PreTrainedModel, GenerationMixin):
    """Decoder-only causal language model with rotary embeddings.

    Pipeline: token embedding -> dropout -> ``config.num_layers``
    TransformerBlocks -> final RMSNorm -> logits computed against the
    token-embedding matrix (no separate LM-head parameters), soft-capped by
    ``config.logit_soft_cap``.
    """

    config_class = GPT2WorkshopConfig
    # The output projection reuses token_embedding.weight directly through
    # F.linear, so there are no separate weights for transformers to tie.
    _tied_weights_keys = {}

    def __init__(self, config):
        super().__init__(config)
        self.config = config
        self.token_embedding = nn.Embedding(config.vocab_size, config.hidden_dim)
        self.embedding_dropout = nn.Dropout(config.dropout)
        self.layers = nn.ModuleList([TransformerBlock(config) for _ in range(config.num_layers)])
        self.final_norm = nn.RMSNorm(config.hidden_dim, eps=1e-6)
        # Rotary tables for positions [0, context_length). Registered as
        # buffers so they move with the module on .to(device) and are
        # included in the state_dict.
        rope_cos, rope_sin = build_rope_cache(config.context_length, config.head_dim, device="cpu", theta=config.rope_theta)
        self.register_buffer("rope_cos", rope_cos)
        self.register_buffer("rope_sin", rope_sin)

    def forward(self, input_ids=None, attention_mask=None, labels=None, **kwargs):
        """Compute logits and, when ``labels`` is given, the cross-entropy loss.

        Args:
            input_ids: ``(batch, seq_len)`` token ids. ``seq_len`` must not
                exceed the number of rows in the rotary tables
                (``config.context_length``).
            attention_mask: Accepted for API compatibility; not used.
            labels: Optional target ids with one entry per input position;
                entries equal to -100 are ignored by the loss.
            **kwargs: Ignored.

        Returns:
            ``CausalLMOutput`` whose ``logits`` are ``(batch, seq_len,
            vocab_size)`` and whose ``loss`` is ``None`` without labels.

        Raises:
            ValueError: If ``input_ids`` is missing or longer than the
                context length.
        """
        if input_ids is None:
            raise ValueError("input_ids must be provided")
        seq_len = input_ids.shape[1]
        max_len = self.rope_cos.shape[0]
        if seq_len > max_len:
            # Without this check the failure surfaces as an opaque broadcast
            # error inside the rotary embedding of the first attention layer.
            raise ValueError(
                f"Sequence length {seq_len} exceeds the model's context length {max_len}"
            )
        # NOTE(review): attention_mask is ignored and attention is purely
        # causal, so padding tokens in a batch are attended to like real
        # tokens — confirm callers never pass padded batches.
        x = self.embedding_dropout(self.token_embedding(input_ids))
        for layer in self.layers:
            x = layer(x, self.rope_cos, self.rope_sin)
        x = self.final_norm(x)
        # Weight-tied output projection: logits = x @ token_embedding.weight.T
        logits = soft_cap_logits(F.linear(x, self.token_embedding.weight), self.config.logit_soft_cap)
        loss = None
        if labels is not None:
            # NOTE(review): logits and labels are compared position-by-position
            # with no internal shift, so labels must already be the next-token
            # targets. Many transformers causal-LM models instead shift
            # internally and accept labels == input_ids — confirm the training
            # pipeline pre-shifts.
            loss = F.cross_entropy(logits.reshape(-1, logits.size(-1)), labels.reshape(-1), ignore_index=-100)
        return CausalLMOutput(loss=loss, logits=logits)

    def prepare_inputs_for_generation(self, input_ids, **kwargs):
        # No KV cache is kept, so every generation step re-runs the full prefix.
        return {"input_ids": input_ids}

    @property
    def all_tied_weights_keys(self):
        # Always empty, for the same reason as _tied_weights_keys.
        return {}