"""CodeLLM - custom decoder-only Transformer for code generation.

GPT-style architecture built from scratch: pre-LayerNorm blocks, rotary
position embeddings (RoPE) and SwiGLU feed-forward layers.  With the default
configuration (tied embeddings) the model has roughly 152M trainable
parameters.
"""

import math
from dataclasses import dataclass
from typing import Dict, List, Optional, Tuple

import torch
import torch.nn as nn
import torch.nn.functional as F

# Per-layer key/value cache: (keys, values), each (batch, n_head, seq, head_dim).
KVCache = Tuple[torch.Tensor, torch.Tensor]


@dataclass
class CodeLLMConfig:
    """Hyperparameters for :class:`CodeLLM`."""

    vocab_size: int = 50257
    # Sizes the precomputed RoPE cache and the causal-mask buffer.
    n_positions: int = 2048
    n_embd: int = 768
    n_layer: int = 12
    n_head: int = 12
    # Hidden width of the SwiGLU feed-forward network.
    n_inner: int = 3072
    dropout: float = 0.1
    layer_norm_epsilon: float = 1e-5
    # Standard deviation of the normal init for Linear/Embedding weights.
    initializer_range: float = 0.02
    # NOTE: not read by the model code in this file; callers pass
    # ``use_cache`` to ``forward`` explicitly.
    use_cache: bool = True
    pad_token_id: int = 50256
    bos_token_id: int = 50256
    eos_token_id: int = 50256
    tie_word_embeddings: bool = True

    @property
    def num_parameters(self) -> int:
        """Return the exact trainable parameter count of the matching model.

        Mirrors the layers built by ``CodeLLM``: token embedding, an optional
        untied output head, and per block a fused QKV + output projection,
        three SwiGLU matrices and two LayerNorms, plus the final LayerNorm.
        """
        d = self.n_embd
        embed = self.vocab_size * d
        lm_head = 0 if self.tie_word_embeddings else self.vocab_size * d
        attn = 4 * d * d  # fused QKV (3*d*d) + output projection (d*d)
        ffn = 3 * d * self.n_inner  # w1, w2 and w3 of the SwiGLU FFN
        block_norms = 2 * 2 * d  # two LayerNorms, each with weight + bias
        final_norm = 2 * d
        per_block = attn + ffn + block_norms
        return embed + lm_head + self.n_layer * per_block + final_norm


class RotaryEmbedding(nn.Module):
    """Rotary position embedding (RoPE) with a precomputed cos/sin cache."""

    def __init__(self, dim: int, max_seq_len: int = 2048, base: int = 10000):
        super().__init__()
        if dim % 2 != 0:
            # The cache concatenates two half-width frequency tables, so an
            # odd dim would produce a cache one element wider than the heads.
            raise ValueError(f"RotaryEmbedding dim must be even, got {dim}")
        self.dim = dim
        inv_freq = 1.0 / (base ** (torch.arange(0, dim, 2).float() / dim))
        self.register_buffer("inv_freq", inv_freq)
        self._build_cache(max_seq_len)

    def _build_cache(self, seq_len: int) -> None:
        """(Re)build the cos/sin tables for positions ``0..seq_len-1``."""
        inv_freq = self.inv_freq
        t = torch.arange(seq_len, device=inv_freq.device).float()
        # Compute angles in float32 regardless of the module's current dtype.
        freqs = t[:, None] * inv_freq.float()[None, :]
        emb = torch.cat([freqs, freqs], dim=-1)
        # Store in the dtype of ``inv_freq`` so a cache rebuilt after the
        # module was cast (e.g. ``.half()``) matches the other buffers.
        # NOTE: these buffers are persistent (part of ``state_dict``); that is
        # kept as-is so existing checkpoints keep loading.
        self.register_buffer("cos_cache", emb.cos()[None, None, :, :].to(inv_freq.dtype))
        self.register_buffer("sin_cache", emb.sin()[None, None, :, :].to(inv_freq.dtype))

    def forward(
        self,
        q: torch.Tensor,
        k: torch.Tensor,
        seq_len: int,
        offset: int = 0,
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        """Rotate ``q`` and ``k`` for positions ``offset..offset+seq_len-1``.

        Args:
            q: Queries shaped ``(batch, n_head, seq_len, dim)``.
            k: Keys with the same shape as ``q``.
            seq_len: Number of positions in ``q``/``k``.
            offset: Absolute position of the first token; non-zero when the
                tokens follow previously cached keys.

        Returns:
            The rotated ``(q, k)`` pair.
        """
        end = offset + seq_len
        if end > self.cos_cache.shape[2]:
            self._build_cache(end)
        cos = self.cos_cache[:, :, offset:end, :]
        sin = self.sin_cache[:, :, offset:end, :]
        return apply_rotary(q, cos, sin), apply_rotary(k, cos, sin)


def rotate_half(x: torch.Tensor) -> torch.Tensor:
    """Map the last dimension ``[x1, x2]`` (two halves) to ``[-x2, x1]``."""
    half = x.shape[-1] // 2
    x1, x2 = x[..., :half], x[..., half:]
    return torch.cat([-x2, x1], dim=-1)


def apply_rotary(x: torch.Tensor, cos: torch.Tensor, sin: torch.Tensor) -> torch.Tensor:
    """Apply the rotary transform ``x * cos + rotate_half(x) * sin``."""
    return (x * cos) + (rotate_half(x) * sin)


class CausalSelfAttention(nn.Module):
    """Multi-head causal self-attention with RoPE and an optional KV cache."""

    def __init__(self, config: CodeLLMConfig):
        super().__init__()
        assert config.n_embd % config.n_head == 0
        self.n_head = config.n_head
        self.n_embd = config.n_embd
        self.head_dim = config.n_embd // config.n_head
        self.dropout = config.dropout
        # Fused projection producing queries, keys and values in one matmul.
        self.c_attn = nn.Linear(config.n_embd, 3 * config.n_embd, bias=False)
        self.c_proj = nn.Linear(config.n_embd, config.n_embd, bias=False)
        self.attn_drop = nn.Dropout(config.dropout)
        self.resid_drop = nn.Dropout(config.dropout)
        self.rotary = RotaryEmbedding(self.head_dim, max_seq_len=config.n_positions)
        # Lower-triangular ones: entry (i, j) is 1 when position i may attend j.
        self.register_buffer(
            "bias",
            torch.tril(torch.ones(config.n_positions, config.n_positions))
            .view(1, 1, config.n_positions, config.n_positions),
        )

    def _additive_mask(
        self,
        q_len: int,
        k_len: int,
        attention_mask: Optional[torch.Tensor],
        dtype: torch.dtype,
        device: torch.device,
    ) -> torch.Tensor:
        """Build an additive mask (0 = keep, -inf = block) for attention.

        Combines the causal constraint for ``q_len`` queries that occupy the
        last ``q_len`` of ``k_len`` key positions with the optional
        caller-supplied ``attention_mask``.
        """
        if k_len <= self.bias.size(-1):
            allowed = self.bias[:, :, k_len - q_len : k_len, :k_len] != 0
        else:
            # Sequence is longer than the precomputed buffer: build on the fly.
            allowed = (
                torch.ones(q_len, k_len, device=device)
                .tril(diagonal=k_len - q_len)
                .bool()
                .view(1, 1, q_len, k_len)
            )
        allowed = allowed.to(device)
        mask = torch.zeros(1, 1, q_len, k_len, dtype=dtype, device=device)
        mask = mask.masked_fill(~allowed, float("-inf"))
        if attention_mask is not None:
            if attention_mask.dtype == torch.bool:
                # Boolean mask: True marks positions that may be attended.
                mask = torch.where(attention_mask, mask, mask.new_tensor(float("-inf")))
            else:
                mask = mask + attention_mask.to(dtype)
        return mask

    def forward(
        self,
        x: torch.Tensor,
        attention_mask: Optional[torch.Tensor] = None,
        past_key_value: Optional[KVCache] = None,
        use_cache: bool = False,
    ) -> Tuple[torch.Tensor, Optional[KVCache]]:
        """Run attention over ``x``.

        Args:
            x: Input of shape ``(batch, seq, n_embd)``.
            attention_mask: Optional mask broadcastable to
                ``(batch, n_head, seq, key_len)``; either additive float
                (added to the attention scores) or boolean (True = attend).
            past_key_value: Cached ``(keys, values)`` from earlier positions.
            use_cache: When True, return the updated ``(keys, values)``.

        Returns:
            ``(output, present)`` where ``present`` is None unless
            ``use_cache`` is set.
        """
        B, T, C = x.size()
        q, k, v = self.c_attn(x).split(self.n_embd, dim=2)
        q = q.view(B, T, self.n_head, self.head_dim).transpose(1, 2)
        k = k.view(B, T, self.n_head, self.head_dim).transpose(1, 2)
        v = v.view(B, T, self.n_head, self.head_dim).transpose(1, 2)

        # New tokens sit *after* the cached ones, so RoPE must start at the
        # cached length rather than at position 0.
        past_len = past_key_value[0].size(2) if past_key_value is not None else 0
        q, k = self.rotary(q, k, seq_len=T, offset=past_len)

        if past_key_value is not None:
            k = torch.cat([past_key_value[0], k], dim=2)
            v = torch.cat([past_key_value[1], v], dim=2)
        present = (k, v) if use_cache else None
        kT = k.size(2)

        # Without a user mask, plain causal attention (no past) or a single
        # query over the full cache needs no explicit mask tensor.
        implicit = attention_mask is None and (past_len == 0 or T == 1)

        if hasattr(F, "scaled_dot_product_attention"):
            # attn_mask and is_causal=True must not be combined, so the causal
            # constraint is folded into the explicit mask whenever one is used.
            mask = None if implicit else self._additive_mask(
                T, kT, attention_mask, q.dtype, q.device
            )
            y = F.scaled_dot_product_attention(
                q, k, v,
                attn_mask=mask,
                dropout_p=self.dropout if self.training else 0.0,
                is_causal=implicit and T > 1,
            )
        else:
            scale = 1.0 / math.sqrt(self.head_dim)
            attn = (q @ k.transpose(-2, -1)) * scale
            attn = attn + self._additive_mask(T, kT, attention_mask, attn.dtype, attn.device)
            attn = F.softmax(attn, dim=-1)
            attn = self.attn_drop(attn)
            y = attn @ v

        y = y.transpose(1, 2).contiguous().view(B, T, C)
        y = self.resid_drop(self.c_proj(y))
        return y, present


class SwiGLUFFN(nn.Module):
    """SwiGLU feed-forward network: ``w3(silu(w1(x)) * w2(x))`` with dropout."""

    def __init__(self, config: CodeLLMConfig):
        super().__init__()
        hidden = config.n_inner
        self.w1 = nn.Linear(config.n_embd, hidden, bias=False)  # gate branch
        self.w2 = nn.Linear(config.n_embd, hidden, bias=False)  # value branch
        self.w3 = nn.Linear(hidden, config.n_embd, bias=False)  # output projection
        self.drop = nn.Dropout(config.dropout)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Return the gated feed-forward output for ``x``."""
        return self.drop(self.w3(F.silu(self.w1(x)) * self.w2(x)))


class TransformerBlock(nn.Module):
    """Pre-LayerNorm Transformer block: attention then SwiGLU FFN, both residual."""

    def __init__(self, config: CodeLLMConfig):
        super().__init__()
        self.ln_1 = nn.LayerNorm(config.n_embd, eps=config.layer_norm_epsilon)
        self.attn = CausalSelfAttention(config)
        self.ln_2 = nn.LayerNorm(config.n_embd, eps=config.layer_norm_epsilon)
        self.ffn = SwiGLUFFN(config)

    def forward(
        self,
        x: torch.Tensor,
        attention_mask: Optional[torch.Tensor] = None,
        past_key_value: Optional[KVCache] = None,
        use_cache: bool = False,
    ) -> Tuple[torch.Tensor, Optional[KVCache]]:
        """Apply the block; returns ``(hidden_states, present_kv_or_None)``."""
        attn_out, present = self.attn(
            self.ln_1(x),
            attention_mask=attention_mask,
            past_key_value=past_key_value,
            use_cache=use_cache,
        )
        x = x + attn_out
        x = x + self.ffn(self.ln_2(x))
        return x, present


class CodeLLM(nn.Module):
    """Decoder-only language model: embedding, N blocks, final norm, LM head."""

    def __init__(self, config: CodeLLMConfig):
        super().__init__()
        self.config = config
        self.transformer = nn.ModuleDict(dict(
            wte=nn.Embedding(config.vocab_size, config.n_embd),
            drop=nn.Dropout(config.dropout),
            h=nn.ModuleList([TransformerBlock(config) for _ in range(config.n_layer)]),
            ln_f=nn.LayerNorm(config.n_embd, eps=config.layer_norm_epsilon),
        ))
        self.lm_head = nn.Linear(config.n_embd, config.vocab_size, bias=False)
        if config.tie_word_embeddings:
            # Share one weight matrix between input embedding and output head.
            self.lm_head.weight = self.transformer.wte.weight

        self.apply(self._init_weights)
        # Projections that write into the residual stream get a smaller,
        # depth-scaled init: the attention output (c_proj) and the FFN output
        # (w3), which plays the same role in the SwiGLU block.
        scaled_std = config.initializer_range / math.sqrt(2 * config.n_layer)
        for name, p in self.named_parameters():
            if name.endswith(("c_proj.weight", "w3.weight")):
                nn.init.normal_(p, mean=0.0, std=scaled_std)
        print(f"CodeLLM initialized | params: {self.num_parameters:,}")

    def _init_weights(self, module: nn.Module) -> None:
        """Normal-initialise Linear/Embedding weights; zero Linear biases."""
        if isinstance(module, nn.Linear):
            nn.init.normal_(module.weight, mean=0.0, std=self.config.initializer_range)
            if module.bias is not None:
                nn.init.zeros_(module.bias)
        elif isinstance(module, nn.Embedding):
            nn.init.normal_(module.weight, mean=0.0, std=self.config.initializer_range)

    @property
    def num_parameters(self) -> int:
        """Number of trainable parameters (shared tensors are counted once)."""
        return sum(p.numel() for p in self.parameters() if p.requires_grad)

    def forward(
        self,
        input_ids: torch.Tensor,
        attention_mask: Optional[torch.Tensor] = None,
        labels: Optional[torch.Tensor] = None,
        past_key_values: Optional[List[KVCache]] = None,
        use_cache: bool = False,
    ) -> Dict[str, object]:
        """Compute logits, and the next-token loss when ``labels`` is given.

        Args:
            input_ids: Token ids of shape ``(batch, seq)``.
            attention_mask: Optional mask forwarded unchanged to every block.
            labels: Optional target ids shaped like ``input_ids``; they are
                shifted by one internally and entries equal to ``-100`` are
                ignored by the loss.
            past_key_values: Per-layer ``(keys, values)`` caches.
            use_cache: When True, return the updated per-layer caches.

        Returns:
            Dict with keys ``"loss"`` (None without labels), ``"logits"`` and
            ``"past_key_values"`` (None unless ``use_cache``).
        """
        x = self.transformer.wte(input_ids)
        x = self.transformer.drop(x)

        presents = []
        for i, block in enumerate(self.transformer.h):
            past_kv = past_key_values[i] if past_key_values else None
            x, present = block(
                x,
                attention_mask=attention_mask,
                past_key_value=past_kv,
                use_cache=use_cache,
            )
            if use_cache:
                presents.append(present)

        x = self.transformer.ln_f(x)
        logits = self.lm_head(x)

        loss = None
        if labels is not None:
            # Position t predicts token t+1.
            shift_logits = logits[..., :-1, :].contiguous()
            shift_labels = labels[..., 1:].contiguous()
            loss = F.cross_entropy(
                shift_logits.view(-1, shift_logits.size(-1)),
                shift_labels.view(-1),
                ignore_index=-100,
            )
        return {
            "loss": loss,
            "logits": logits,
            "past_key_values": presents if use_cache else None,
        }

    @staticmethod
    def _filter_logits(
        logits: torch.Tensor,
        top_k: Optional[int],
        top_p: Optional[float],
    ) -> torch.Tensor:
        """Apply top-k then nucleus (top-p) filtering to ``(batch, vocab)`` logits."""
        neg_inf = float("-inf")
        if top_k is not None and top_k > 0:
            kth = torch.topk(logits, min(top_k, logits.size(-1))).values[:, -1:]
            logits = logits.masked_fill(logits < kth, neg_inf)
        if top_p is not None and top_p < 1.0:
            sorted_logits, sorted_idx = torch.sort(logits, descending=True)
            sorted_probs = F.softmax(sorted_logits, dim=-1)
            cumprobs = torch.cumsum(sorted_probs, dim=-1)
            # Drop a token once the mass *before* it already exceeds top_p, so
            # the highest-probability token is always kept.
            remove = (cumprobs - sorted_probs) > top_p
            sorted_logits = sorted_logits.masked_fill(remove, neg_inf)
            logits = logits.scatter(1, sorted_idx, sorted_logits)
        return logits

    @torch.no_grad()
    def generate(
        self,
        input_ids: torch.Tensor,
        max_new_tokens: int = 256,
        temperature: float = 0.8,
        top_k: int = 50,
        top_p: float = 0.95,
        eos_token_id: Optional[int] = None,
    ) -> torch.Tensor:
        """Autoregressively extend ``input_ids`` using the KV cache.

        Puts the model in eval mode (and leaves it there).

        Args:
            input_ids: Prompt token ids of shape ``(batch, seq)``.
            max_new_tokens: Upper bound on generated tokens.
            temperature: Softmax temperature; ``<= 0`` selects greedy decoding.
            top_k: Keep only the ``top_k`` most likely tokens (``0``/None
                disables).
            top_p: Nucleus sampling threshold (``>= 1``/None disables).
            eos_token_id: Stop token; defaults to ``config.eos_token_id``.

        Returns:
            Prompt plus generated tokens, ``(batch, seq + n_generated)``.  Rows
            that finished early are filled with ``config.pad_token_id``.
        """
        self.eval()
        # ``is None`` rather than ``or``: token id 0 is a valid EOS id.
        eos = self.config.eos_token_id if eos_token_id is None else eos_token_id
        pad = self.config.pad_token_id
        finished = torch.zeros(input_ids.size(0), dtype=torch.bool, device=input_ids.device)
        past_key_values = None

        for _ in range(max_new_tokens):
            # After the prompt pass only the newest token needs to be fed.
            step_input = input_ids if past_key_values is None else input_ids[:, -1:]
            out = self(step_input, past_key_values=past_key_values, use_cache=True)
            past_key_values = out["past_key_values"]
            logits = out["logits"][:, -1, :]

            if temperature <= 0:
                next_tok = logits.argmax(dim=-1, keepdim=True)
            else:
                filtered = self._filter_logits(logits / temperature, top_k, top_p)
                probs = F.softmax(filtered, dim=-1)
                next_tok = torch.multinomial(probs, num_samples=1)

            # Rows that already emitted EOS only receive padding from now on.
            next_tok = next_tok.masked_fill(finished.unsqueeze(1), pad)
            input_ids = torch.cat([input_ids, next_tok], dim=1)
            finished = finished | (next_tok.squeeze(1) == eos)
            if finished.all():
                break
        return input_ids