import torch
import torch.nn as nn
import torch.nn.functional as F
from transformers import PretrainedConfig, PreTrainedModel, GenerationMixin
from transformers.modeling_outputs import CausalLMOutput


class WorkshopGPTConfig(PretrainedConfig):
    """Configuration for the WorkshopGPT decoder-only language model.

    Args:
        n_layer: Number of transformer blocks.
        n_head: Number of attention heads per block.
        n_embd: Width of the residual stream (token embedding size).
        vocab_size: Number of rows in the token embedding table.
        block_size: Sequence length used to pre-size the rotary cache.
        n_inner: Hidden width of the MLP inside each block.
        rope_theta: Base of the rotary positional embedding frequencies.
        **kwargs: Forwarded unchanged to ``PretrainedConfig.__init__``.
    """

    model_type = "workshop_gpt"
    # Lets code that reads ``config.num_hidden_layers`` resolve to ``n_layer``.
    attribute_map = {"num_hidden_layers": "n_layer"}

    def __init__(
        self,
        n_layer=12,
        n_head=12,
        n_embd=768,
        vocab_size=50304,
        block_size=1024,
        n_inner=3072,
        rope_theta=10000.0,
        **kwargs,
    ):
        super().__init__(**kwargs)
        self.n_layer = n_layer
        self.n_head = n_head
        self.n_embd = n_embd
        self.vocab_size = vocab_size
        self.block_size = block_size
        self.n_inner = n_inner
        self.rope_theta = rope_theta


class RMSNorm(nn.Module):
    """Root-mean-square normalisation over the last dimension.

    Args:
        dim: Size of the last dimension; also the length of the learned scale.
        eps: Constant added to the mean square before the reciprocal sqrt.
    """

    def __init__(self, dim, eps=1e-6):
        super().__init__()
        self.scale = nn.Parameter(torch.ones(dim))
        self.eps = eps

    def forward(self, x):
        """Return ``x`` divided by its RMS (last dim) and multiplied by ``scale``."""
        # The statistic is computed in float32 and cast back to x's dtype
        # before it is applied, so low-precision inputs keep a stable norm.
        mean_square = x.float().pow(2).mean(-1, keepdim=True)
        inv_rms = torch.rsqrt(mean_square + self.eps).type_as(x)
        return x * inv_rms * self.scale


class RotaryPositionalEmbeddings(nn.Module):
    """Rotary positional embeddings applied to interleaved feature pairs.

    The cos/sin table is built lazily on first use and kept as a plain
    attribute (not a registered buffer), so it is not part of the state dict.

    Args:
        dim: Size of the last dimension of the tensors to rotate; must be even.
        max_seq_len: Minimum number of positions built into the cache.
        base: Base of the geometric frequency progression.

    Raises:
        ValueError: If ``dim`` is odd (features are rotated in pairs).
    """

    def __init__(self, dim, max_seq_len=1024, base=10000.0):
        super().__init__()
        if dim % 2 != 0:
            raise ValueError(f"rotary dim must be even, got {dim}")
        self.dim = dim
        self.max_seq_len = max_seq_len
        self.base = base
        self.cache = None

    def _build_cache(self, seq_len, device):
        """Store a ``(seq_len, dim // 2, 2)`` tensor of ``[cos, sin]`` in ``self.cache``."""
        exponents = torch.arange(0, self.dim, 2, device=device).float() / self.dim
        theta = 1.0 / (self.base ** exponents)
        # Build positions directly in theta's dtype instead of relying on
        # int64/float32 promotion inside torch.outer.
        seq = torch.arange(seq_len, device=device, dtype=theta.dtype)
        freqs = torch.outer(seq, theta)
        self.cache = torch.stack([freqs.cos(), freqs.sin()], dim=-1)

    def forward(self, x, *, input_pos=None):
        """Rotate ``x`` according to its token positions.

        Args:
            x: Tensor whose last two dimensions are ``(seq_len, dim)``.
            input_pos: Optional position indices used to index the cache.
                When ``None``, positions ``0 .. seq_len - 1`` are used. The
                broadcast below prepends ``x.ndim - 2`` singleton dimensions
                to the selected cos/sin, which lines up with ``x`` for 1-D
                indices of length ``seq_len``.

        Returns:
            Tensor with the same shape and dtype as ``x``.
        """
        seq_len = x.shape[-2]

        # The cache must cover every position that will be looked up,
        # including explicit positions beyond the current sequence length.
        required = seq_len
        if input_pos is not None:
            positions = torch.as_tensor(input_pos)
            if positions.numel() > 0:
                required = max(required, int(positions.max()) + 1)

        cache = self.cache
        needs_rebuild = (
            cache is None
            or cache.shape[0] < required
            or cache.device != x.device
            # A cache created under torch.inference_mode() cannot be saved
            # for backward; rebuild it once we are back in normal mode.
            or (cache.is_inference() and not torch.is_inference_mode_enabled())
        )
        if needs_rebuild:
            self._build_cache(max(required, self.max_seq_len), x.device)

        if input_pos is None:
            selected = self.cache[:seq_len]
        else:
            selected = self.cache[input_pos]

        # Split the last dim into (dim // 2) interleaved pairs; rotate in fp32.
        x1, x2 = x.float().unflatten(-1, (-1, 2)).unbind(-1)
        cos, sin = selected.unbind(-1)
        shape = [1] * (x.ndim - 2) + list(cos.shape)
        cos, sin = cos.view(*shape), sin.view(*shape)

        rotated = torch.stack(
            [x1 * cos - x2 * sin, x1 * sin + x2 * cos], dim=-1
        )
        return rotated.flatten(-2).type_as(x)


class ReluSquaredMLP(nn.Module):
    """Two-layer bias-free MLP with a squared-ReLU activation.

    Args:
        dim: Input and output feature size.
        hidden_dim: Width of the hidden layer.
    """

    def __init__(self, dim, hidden_dim):
        super().__init__()
        self.fc_in = nn.Linear(dim, hidden_dim, bias=False)
        self.fc_out = nn.Linear(hidden_dim, dim, bias=False)

    def forward(self, x):
        """Return ``fc_out(relu(fc_in(x)) ** 2)``."""
        return self.fc_out(F.relu(self.fc_in(x)).square())


class CausalSelfAttention(nn.Module):
    """Multi-head causal self-attention with rotary position embeddings.

    Args:
        n_embd: Model width; input and output feature size.
        n_head: Number of attention heads.
        head_dim: Feature size per head; ``n_head * head_dim`` must equal
            ``n_embd``.
        rope: Module applied to the queries and keys after they are split
            into heads.

    Raises:
        ValueError: If ``n_head * head_dim != n_embd``.
    """

    def __init__(self, n_embd, n_head, head_dim, rope):
        super().__init__()
        # The projections emit n_embd features that are viewed as
        # (n_head, head_dim); fail here rather than inside forward().
        if n_head * head_dim != n_embd:
            raise ValueError(
                f"n_head * head_dim must equal n_embd, "
                f"got {n_head} * {head_dim} != {n_embd}"
            )
        self.n_head = n_head
        self.head_dim = head_dim
        self.q_proj = nn.Linear(n_embd, n_embd, bias=False)
        self.k_proj = nn.Linear(n_embd, n_embd, bias=False)
        self.v_proj = nn.Linear(n_embd, n_embd, bias=False)
        self.output_proj = nn.Linear(n_embd, n_embd, bias=False)
        self.rope = rope

    def _split_heads(self, projected, batch, seq_len):
        """Reshape ``(B, T, C)`` to ``(B, n_head, T, head_dim)``."""
        return projected.view(
            batch, seq_len, self.n_head, self.head_dim
        ).transpose(1, 2)

    def forward(self, x):
        """Apply causal attention to ``x`` of shape ``(B, T, C)``; same shape out."""
        B, T, C = x.shape
        q = self._split_heads(self.q_proj(x), B, T)
        k = self._split_heads(self.k_proj(x), B, T)
        v = self._split_heads(self.v_proj(x), B, T)
        q, k = self.rope(q), self.rope(k)
        y = F.scaled_dot_product_attention(q, k, v, is_causal=True)
        # Merge the heads back into the channel dimension.
        return self.output_proj(y.transpose(1, 2).contiguous().view(B, T, C))


class TransformerBlock(nn.Module):
    """Pre-norm transformer block: attention then MLP, each with a residual.

    Args:
        config: Object providing ``n_embd``, ``n_head``, ``n_inner``,
            ``block_size`` and ``rope_theta``.
    """

    def __init__(self, config):
        super().__init__()
        hd = config.n_embd // config.n_head
        # Each block owns its own rotary module and therefore its own cache.
        rope = RotaryPositionalEmbeddings(hd, config.block_size, config.rope_theta)
        self.sa_norm = RMSNorm(config.n_embd)
        self.attn = CausalSelfAttention(config.n_embd, config.n_head, hd, rope)
        self.mlp_norm = RMSNorm(config.n_embd)
        self.mlp = ReluSquaredMLP(config.n_embd, config.n_inner)

    def forward(self, x):
        """Return ``x`` after the attention and MLP residual updates."""
        x = x + self.attn(self.sa_norm(x))
        return x + self.mlp(self.mlp_norm(x))


class WorkshopGPTForCausalLM(PreTrainedModel, GenerationMixin):
    """Decoder-only language model whose output head reuses the embedding matrix.

    There is no separate output-projection parameter: ``forward`` multiplies
    the final hidden states by ``tok_embeddings.weight`` directly.
    """

    config_class = WorkshopGPTConfig
    # No parameter pairs are declared as tied; the embedding weight is simply
    # used twice inside forward().
    _tied_weights_keys = {}

    def __init__(self, config):
        super().__init__(config)
        self.tok_embeddings = nn.Embedding(config.vocab_size, config.n_embd)
        self.layers = nn.ModuleList(
            [TransformerBlock(config) for _ in range(config.n_layer)]
        )
        self.norm = RMSNorm(config.n_embd)

    def forward(self, input_ids, **kwargs):
        """Compute next-token logits.

        Args:
            input_ids: Integer token ids of shape ``(batch, seq_len)``.
            **kwargs: Accepted and ignored; nothing in this method reads them.

        Returns:
            ``CausalLMOutput`` with only ``logits`` set, of shape
            ``(batch, seq_len, vocab_size)``. No loss is computed.
        """
        x = self.tok_embeddings(input_ids)
        for layer in self.layers:
            x = layer(x)
        logits = F.linear(self.norm(x), self.tok_embeddings.weight)
        return CausalLMOutput(logits=logits)

    def prepare_inputs_for_generation(self, input_ids, **kwargs):
        """Return the inputs for one generation step.

        Only ``input_ids`` is forwarded; no cached key/value state is passed,
        so each forward call processes the full sequence it is given.
        """
        return {"input_ids": input_ids}

    @property
    def all_tied_weights_keys(self):
        """Return an empty mapping.

        NOTE(review): presumably overrides the transformers attribute of the
        same name so that no weight-tying bookkeeping is attempted; confirm
        against the installed transformers version.
        """
        return {}