| """ |
| CodeLLM - Custom Decoder-only Transformer Architecture |
| Built from scratch for code generation. |
Architecture: GPT-style decoder with rotary position embeddings and SwiGLU feed-forward layers (~152M parameters with the default config)
| """ |
|
|
| import math |
| import torch |
| import torch.nn as nn |
| import torch.nn.functional as F |
| from dataclasses import dataclass |
| from typing import Optional, Tuple |
|
|
|
|
@dataclass
class CodeLLMConfig:
    """Hyperparameters for building a CodeLLM model."""

    vocab_size: int = 50257  # rows of the token embedding / columns of the LM head
    n_positions: int = 2048  # initial rotary-table length and causal-mask buffer size
    n_embd: int = 768  # model (residual stream) width
    n_layer: int = 12  # number of transformer blocks
    n_head: int = 12  # attention heads; must divide n_embd
    n_inner: int = 3072  # hidden width of the SwiGLU feed-forward layer
    dropout: float = 0.1
    layer_norm_epsilon: float = 1e-5
    initializer_range: float = 0.02  # std of the normal weight initialisation
    use_cache: bool = True  # NOTE(review): not read by the model code in this file
    pad_token_id: int = 50256
    bos_token_id: int = 50256
    eos_token_id: int = 50256  # default stop token for generation
    tie_word_embeddings: bool = True  # share one weight between embedding and LM head

    @property
    def num_parameters(self):
        """Return the exact trainable parameter count of a model built from this config.

        Mirrors the layout of the model in this file:
        - a token embedding, plus a separate LM head only when weights are untied;
        - per layer, bias-free qkv and output projections, three bias-free SwiGLU
          matrices, and two LayerNorms (weight + bias each);
        - one final LayerNorm.
        Rotary tables and the causal mask are buffers, not parameters.
        """
        d = self.n_embd
        embed = self.vocab_size * d
        head = 0 if self.tie_word_embeddings else self.vocab_size * d
        attn = 4 * d * d  # c_attn (3*d*d) + c_proj (d*d)
        ffn = 3 * d * self.n_inner  # w1, w2 (d -> n_inner) and w3 (n_inner -> d)
        norms = 2 * (2 * d)  # ln_1 and ln_2, each with weight and bias
        final_norm = 2 * d  # ln_f
        return embed + head + self.n_layer * (attn + ffn + norms) + final_norm
|
|
|
|
class RotaryEmbedding(nn.Module):
    """Rotary position embedding (RoPE) for query/key heads.

    Precomputes cos/sin tables for positions ``[0, max_seq_len)`` and rebuilds
    them on demand when a longer range is requested.
    """

    def __init__(self, dim: int, max_seq_len: int = 2048, base: int = 10000):
        super().__init__()
        self.dim = dim
        # One inverse frequency per channel pair: base ** (-2i / dim).
        inv_freq = 1.0 / (base ** (torch.arange(0, dim, 2).float() / dim))
        self.register_buffer("inv_freq", inv_freq)
        self._build_cache(max_seq_len)

    def _build_cache(self, seq_len: int):
        """(Re)build cos/sin tables of shape (1, 1, seq_len, dim) for positions 0..seq_len-1."""
        t = torch.arange(seq_len, device=self.inv_freq.device).float()
        freqs = torch.einsum("i,j->ij", t, self.inv_freq)
        emb = torch.cat([freqs, freqs], dim=-1)
        # register_buffer on a name already in the buffer dict replaces that buffer.
        self.register_buffer("cos_cache", emb.cos()[None, None, :, :])
        self.register_buffer("sin_cache", emb.sin()[None, None, :, :])

    def forward(self, q: torch.Tensor, k: torch.Tensor, seq_len: int, offset: int = 0):
        """Rotate ``q`` and ``k`` as if they occupy positions ``offset .. offset + seq_len - 1``.

        Args:
            q, k: tensors whose last two dims are (seq_len, dim); in this file they
                are (batch, n_head, seq_len, head_dim).
            seq_len: number of positions present in ``q``/``k``.
            offset: absolute position of the first entry. This is non-zero when
                earlier positions are already held in a KV cache. The default 0
                preserves the original behaviour.

        Returns:
            ``(q_rotated, k_rotated)``.
        """
        end = offset + seq_len
        if end > self.cos_cache.shape[2]:
            self._build_cache(end)
        cos = self.cos_cache[:, :, offset:end, :]
        sin = self.sin_cache[:, :, offset:end, :]
        return apply_rotary(q, cos, sin), apply_rotary(k, cos, sin)


def rotate_half(x: torch.Tensor) -> torch.Tensor:
    """Split the last dim into halves (x1, x2) and return cat(-x2, x1)."""
    x1, x2 = x[..., : x.shape[-1] // 2], x[..., x.shape[-1] // 2 :]
    return torch.cat([-x2, x1], dim=-1)


def apply_rotary(x: torch.Tensor, cos: torch.Tensor, sin: torch.Tensor) -> torch.Tensor:
    """Apply the rotary rotation x*cos + rotate_half(x)*sin (cos/sin broadcast against x)."""
    return (x * cos) + (rotate_half(x) * sin)


class CausalSelfAttention(nn.Module):
    """Multi-head causal self-attention with rotary embeddings and an optional KV cache."""

    def __init__(self, config: CodeLLMConfig):
        super().__init__()
        assert config.n_embd % config.n_head == 0, "n_embd must be divisible by n_head"
        self.n_head = config.n_head
        self.n_embd = config.n_embd
        self.head_dim = config.n_embd // config.n_head
        self.dropout = config.dropout
        self.c_attn = nn.Linear(config.n_embd, 3 * config.n_embd, bias=False)
        self.c_proj = nn.Linear(config.n_embd, config.n_embd, bias=False)
        self.attn_drop = nn.Dropout(config.dropout)
        self.resid_drop = nn.Dropout(config.dropout)
        self.rotary = RotaryEmbedding(self.head_dim, max_seq_len=config.n_positions)
        # Lower-triangular (n_positions x n_positions) mask. forward() now builds its
        # causal mask on the fly, so lengths beyond n_positions work. The buffer stays
        # registered so existing state_dicts containing "bias" still load.
        self.register_buffer(
            "bias",
            torch.tril(torch.ones(config.n_positions, config.n_positions))
            .view(1, 1, config.n_positions, config.n_positions),
        )

    @staticmethod
    def _causal_mask(t: int, s: int, device) -> torch.Tensor:
        """Bool (t, s) mask aligned bottom-right: query i may attend key j iff j <= i + s - t."""
        return torch.ones(t, s, dtype=torch.bool, device=device).tril(diagonal=s - t)

    @staticmethod
    def _to_additive_mask(mask, dtype):
        """Convert a bool mask (True = may attend) to a 0/-inf float mask.

        ``None`` and non-bool masks are returned unchanged (treated as additive).
        """
        if mask is None or mask.dtype != torch.bool:
            return mask
        additive = torch.zeros(mask.shape, dtype=dtype, device=mask.device)
        return additive.masked_fill(~mask, float("-inf"))

    def forward(self, x, attention_mask=None, past_key_value=None, use_cache=False):
        """Attend from the T new positions in ``x`` to all S = cached + T key positions.

        Args:
            x: (B, T, n_embd) input.
            attention_mask: optional mask broadcastable to (B, n_head, T, S), applied
                on top of the causal mask. A bool mask means True = may attend; any
                other dtype is added to the attention scores. NOTE(review): a 2-D
                (B, S) padding mask must be expanded, e.g. to (B, 1, 1, S), before
                it is passed in.
            past_key_value: optional ``(k, v)`` from a previous call. Each is
                (B, n_head, past_len, head_dim) with rotary already applied.
            use_cache: when True, also return the updated ``(k, v)``.

        Returns:
            ``(y, present)``, where ``y`` is (B, T, n_embd) and ``present`` is the
            ``(k, v)`` cache or ``None``.
        """
        B, T, C = x.size()
        q, k, v = self.c_attn(x).split(self.n_embd, dim=2)
        q = q.view(B, T, self.n_head, self.head_dim).transpose(1, 2)
        k = k.view(B, T, self.n_head, self.head_dim).transpose(1, 2)
        v = v.view(B, T, self.n_head, self.head_dim).transpose(1, 2)

        # New tokens come after everything already cached. Rotating them from
        # position 0 would give every cached-decoding step the wrong position.
        past_len = 0 if past_key_value is None else past_key_value[0].size(2)
        q, k = self.rotary(q, k, seq_len=T, offset=past_len)

        if past_key_value is not None:
            k = torch.cat([past_key_value[0], k], dim=2)
            v = torch.cat([past_key_value[1], v], dim=2)
        present = (k, v) if use_cache else None

        S = k.size(2)
        extra_mask = self._to_additive_mask(attention_mask, q.dtype)
        dropout_p = self.dropout if self.training else 0.0

        if hasattr(F, "scaled_dot_product_attention"):
            if extra_mask is None and past_len == 0:
                # Square score matrix: SDPA's built-in causal mask is exactly right.
                y = F.scaled_dot_product_attention(q, k, v, dropout_p=dropout_p, is_causal=True)
            elif extra_mask is None and T == 1:
                # A single new query may attend to every cached key.
                y = F.scaled_dot_product_attention(q, k, v, dropout_p=dropout_p)
            else:
                # SDPA rejects attn_mask combined with is_causal=True, and its causal
                # mask is top-left aligned when T != S. So build the bottom-right
                # causal mask explicitly and fold in the caller's mask.
                mask = torch.zeros(T, S, dtype=q.dtype, device=q.device)
                mask = mask.masked_fill(~self._causal_mask(T, S, q.device), float("-inf"))
                if extra_mask is not None:
                    mask = (mask + extra_mask).to(q.dtype)
                y = F.scaled_dot_product_attention(q, k, v, attn_mask=mask, dropout_p=dropout_p)
        else:
            scale = 1.0 / math.sqrt(self.head_dim)
            attn = (q @ k.transpose(-2, -1)) * scale
            attn = attn.masked_fill(~self._causal_mask(T, S, q.device), float("-inf"))
            if extra_mask is not None:
                attn = attn + extra_mask
            attn = F.softmax(attn, dim=-1)
            attn = self.attn_drop(attn)
            y = attn @ v
        y = y.transpose(1, 2).contiguous().view(B, T, C)
        y = self.resid_drop(self.c_proj(y))
        return y, present
|
|
|
|
class SwiGLUFFN(nn.Module):
    """Gated feed-forward layer: drop(w3(silu(w1(x)) * w2(x)))."""

    def __init__(self, config: CodeLLMConfig):
        super().__init__()
        d_model, d_hidden = config.n_embd, config.n_inner
        # Two bias-free input projections (gate and value) and one output projection.
        self.w1 = nn.Linear(d_model, d_hidden, bias=False)
        self.w2 = nn.Linear(d_model, d_hidden, bias=False)
        self.w3 = nn.Linear(d_hidden, d_model, bias=False)
        self.drop = nn.Dropout(config.dropout)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Map (..., n_embd) -> (..., n_embd) through the SiLU-gated hidden layer."""
        gate = F.silu(self.w1(x))
        value = self.w2(x)
        projected = self.w3(gate * value)
        return self.drop(projected)
|
|
|
|
class TransformerBlock(nn.Module):
    """Pre-norm decoder layer: x += attn(ln_1(x)); then x += ffn(ln_2(x))."""

    def __init__(self, config: CodeLLMConfig):
        super().__init__()
        width, eps = config.n_embd, config.layer_norm_epsilon
        # Registration order (ln_1, attn, ln_2, ffn) is kept so parameter order is unchanged.
        self.ln_1 = nn.LayerNorm(width, eps=eps)
        self.attn = CausalSelfAttention(config)
        self.ln_2 = nn.LayerNorm(width, eps=eps)
        self.ffn = SwiGLUFFN(config)

    def forward(self, x, attention_mask=None, past_key_value=None, use_cache=False):
        """Return ``(hidden, present)``; ``present`` is the attention KV cache (or None)."""
        normed = self.ln_1(x)
        attended, present = self.attn(
            normed,
            attention_mask=attention_mask,
            past_key_value=past_key_value,
            use_cache=use_cache,
        )
        hidden = x + attended
        hidden = hidden + self.ffn(self.ln_2(hidden))
        return hidden, present
|
|
|
|
class CodeLLM(nn.Module):
    """Decoder-only LM: token embedding -> n_layer TransformerBlocks -> LayerNorm -> LM head."""

    def __init__(self, config: CodeLLMConfig):
        super().__init__()
        self.config = config
        self.transformer = nn.ModuleDict(dict(
            wte=nn.Embedding(config.vocab_size, config.n_embd),
            drop=nn.Dropout(config.dropout),
            h=nn.ModuleList([TransformerBlock(config) for _ in range(config.n_layer)]),
            ln_f=nn.LayerNorm(config.n_embd, eps=config.layer_norm_epsilon),
        ))
        self.lm_head = nn.Linear(config.n_embd, config.vocab_size, bias=False)
        if config.tie_word_embeddings:
            # One shared tensor serves as input embedding and output projection.
            self.lm_head.weight = self.transformer.wte.weight
        self.apply(self._init_weights)
        # GPT-2-style scaled init for every projection that writes into the residual
        # stream: the attention output (c_proj) and the SwiGLU output (ffn.w3).
        # There are 2 such writes per layer, hence 2 * n_layer.
        residual_std = config.initializer_range / math.sqrt(2 * config.n_layer)
        for name, p in self.named_parameters():
            if name.endswith(("c_proj.weight", "ffn.w3.weight")):
                nn.init.normal_(p, mean=0.0, std=residual_std)
        print(f"CodeLLM initialized | params: {self.num_parameters:,}")

    def _init_weights(self, module):
        """Normal(0, initializer_range) init for Linear/Embedding weights; zero Linear biases."""
        if isinstance(module, nn.Linear):
            nn.init.normal_(module.weight, mean=0.0, std=self.config.initializer_range)
            if module.bias is not None:
                nn.init.zeros_(module.bias)
        elif isinstance(module, nn.Embedding):
            nn.init.normal_(module.weight, mean=0.0, std=self.config.initializer_range)

    @property
    def num_parameters(self):
        """Number of trainable parameters (a tied embedding/head weight is counted once)."""
        return sum(p.numel() for p in self.parameters() if p.requires_grad)

    def forward(self, input_ids, attention_mask=None, labels=None, past_key_values=None, use_cache=False):
        """Run the model.

        Args:
            input_ids: (B, T) token ids.
            attention_mask: forwarded unchanged to every attention layer.
            labels: optional (B, T) targets. The loss predicts token t+1 from
                position t, and labels equal to -100 are ignored.
            past_key_values: per-layer ``(k, v)`` tuples from an earlier call made
                with ``use_cache=True``.
            use_cache: when True, return the updated per-layer caches.

        Returns:
            dict with ``"loss"`` (None without labels), ``"logits"`` (B, T, vocab_size)
            and ``"past_key_values"`` (list of per-layer caches, or None).
        """
        x = self.transformer.drop(self.transformer.wte(input_ids))
        presents = [] if use_cache else None
        for i, block in enumerate(self.transformer.h):
            past_kv = past_key_values[i] if past_key_values else None
            x, present = block(x, attention_mask=attention_mask, past_key_value=past_kv, use_cache=use_cache)
            if use_cache:
                presents.append(present)
        x = self.transformer.ln_f(x)
        logits = self.lm_head(x)
        loss = None
        if labels is not None:
            shift_logits = logits[..., :-1, :].contiguous()
            shift_labels = labels[..., 1:].contiguous()
            loss = F.cross_entropy(
                shift_logits.view(-1, shift_logits.size(-1)),
                shift_labels.view(-1),
                ignore_index=-100,
            )
        return {"loss": loss, "logits": logits, "past_key_values": presents}

    @staticmethod
    def _filter_logits(logits, top_k, top_p):
        """Set (B, V) logits outside the top-k and top-p (nucleus) sets to -inf."""
        if top_k is not None and top_k > 0:
            kth = torch.topk(logits, min(top_k, logits.size(-1))).values[:, [-1]]
            logits = logits.masked_fill(logits < kth, float("-inf"))
        if top_p is not None and top_p < 1.0:
            sorted_logits, sorted_idx = torch.sort(logits, descending=True)
            sorted_probs = F.softmax(sorted_logits, dim=-1)
            # Drop a token once the probability mass strictly ahead of it already
            # exceeds top_p. The most likely token is always kept (for top_p >= 0).
            remove = (sorted_probs.cumsum(dim=-1) - sorted_probs) > top_p
            sorted_logits = sorted_logits.masked_fill(remove, float("-inf"))
            logits = logits.scatter(1, sorted_idx, sorted_logits)
        return logits

    @torch.no_grad()
    def generate(self, input_ids, max_new_tokens=256, temperature=0.8, top_k=50, top_p=0.95, eos_token_id=None):
        """Autoregressively extend ``input_ids`` by up to ``max_new_tokens`` tokens.

        Args:
            input_ids: (B, T) prompt ids. No attention mask is used, so prompts are
                assumed unpadded.
            max_new_tokens: upper bound on generated tokens.
            temperature: logit divisor; ``0`` selects greedy (argmax) decoding.
            top_k: keep only the k highest logits (``None`` or ``0`` disables).
            top_p: nucleus threshold (``None`` or ``>= 1.0`` disables).
            eos_token_id: stop token; defaults to ``config.eos_token_id`` (an explicit
                0 is honoured).

        Returns:
            (B, T + n) ids with n <= max_new_tokens. After a row emits EOS, its later
            positions are filled with ``config.pad_token_id``. Generation stops once
            every row has emitted EOS. The model's train/eval mode is restored on exit.
        """
        was_training = self.training
        self.eval()
        try:
            eos = self.config.eos_token_id if eos_token_id is None else eos_token_id
            pad = self.config.pad_token_id
            finished = torch.zeros(input_ids.size(0), dtype=torch.bool, device=input_ids.device)
            past_key_values = None
            for _ in range(max_new_tokens):
                # Once the cache exists it holds all earlier positions; feed only the newest token.
                input_slice = input_ids if past_key_values is None else input_ids[:, -1:]
                out = self.forward(input_slice, past_key_values=past_key_values, use_cache=True)
                past_key_values = out["past_key_values"]
                logits = out["logits"][:, -1, :]
                if temperature == 0:
                    next_tok = logits.argmax(dim=-1, keepdim=True)
                else:
                    logits = self._filter_logits(logits / temperature, top_k, top_p)
                    next_tok = torch.multinomial(F.softmax(logits, dim=-1), num_samples=1)
                next_tok = next_tok.masked_fill(finished.unsqueeze(1), pad)
                input_ids = torch.cat([input_ids, next_tok], dim=1)
                finished |= next_tok.squeeze(1) == eos
                if finished.all():
                    break
            return input_ids
        finally:
            self.train(was_training)