# CodeLLM / model/architecture.py
# Added by devoppro in commit eec13c0 ("Create model/architecture.py").
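"""
CodeLLM - Custom Decoder-only Transformer Architecture
Built from scratch for code generation.
Architecture: GPT-style pre-norm decoder with rotary position embeddings
(RoPE) and SwiGLU feed-forward layers; ~152M parameters with the default
config (tied input/output embeddings).
"""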
import math
import torch
import torch.nn as nn
import torch.nn.functional as F
from dataclasses import dataclass
from typing import Optional, Tuple
@dataclass
class CodeLLMConfig:
    """Hyperparameters for :class:`CodeLLM`.

    The defaults describe a 12-layer, 768-wide, 12-head model with a
    50257-token vocabulary and a SwiGLU hidden size of 3072.
    """

    vocab_size: int = 50257
    n_positions: int = 2048  # initial RoPE / causal-buffer length
    n_embd: int = 768
    n_layer: int = 12
    n_head: int = 12
    n_inner: int = 3072  # SwiGLU hidden width
    dropout: float = 0.1
    layer_norm_epsilon: float = 1e-5
    initializer_range: float = 0.02
    use_cache: bool = True
    # pad/bos/eos all default to id 50256, the last entry of the vocabulary.
    pad_token_id: int = 50256
    bos_token_id: int = 50256
    eos_token_id: int = 50256
    tie_word_embeddings: bool = True

    @property
    def num_parameters(self):
        """Return the trainable-parameter count of a CodeLLM built from this config.

        Mirrors the module structure in this file:
          * token embedding ``vocab_size x n_embd``;
          * a separate ``lm_head`` of the same size only when
            ``tie_word_embeddings`` is False;
          * per layer: fused QKV (``3 d^2``) plus output projection (``d^2``),
            a SwiGLU FFN with three matrices (``3 d h``), and two LayerNorms
            (weight + bias each, ``4 d``);
          * the final LayerNorm (``2 d``).
        All linear layers are bias-free.
        """
        d, h = self.n_embd, self.n_inner
        embed = self.vocab_size * d
        lm_head = 0 if self.tie_word_embeddings else self.vocab_size * d
        attn = 4 * d * d
        ffn = 3 * d * h  # w1, w2 (d -> h) and w3 (h -> d)
        norms = 2 * (2 * d)  # ln_1 and ln_2, weight + bias
        final_norm = 2 * d
        return embed + lm_head + self.n_layer * (attn + ffn + norms) + final_norm
class RotaryEmbedding(nn.Module):
    """Rotary position embedding (RoPE) for query/key tensors.

    Precomputes cos/sin tables of shape (1, 1, max_seq_len, dim) and rebuilds
    them on demand when positions beyond the cached range are requested.
    """

    def __init__(self, dim: int, max_seq_len: int = 2048, base: int = 10000):
        super().__init__()
        self.dim = dim
        inv_freq = 1.0 / (base ** (torch.arange(0, dim, 2).float() / dim))
        self.register_buffer("inv_freq", inv_freq)
        self._build_cache(max_seq_len)

    def _build_cache(self, seq_len: int):
        """(Re)build the cos/sin tables for positions ``0 .. seq_len - 1``."""
        # Compute in fp32 even if the module was cast to half precision.
        inv_freq = self.inv_freq.float()
        t = torch.arange(seq_len, device=inv_freq.device).float()
        freqs = torch.einsum("i,j->ij", t, inv_freq)
        emb = torch.cat([freqs, freqs], dim=-1)
        # Re-registering an existing buffer name replaces the old tensor.
        self.register_buffer("cos_cache", emb.cos()[None, None, :, :])
        self.register_buffer("sin_cache", emb.sin()[None, None, :, :])

    def forward(self, q: torch.Tensor, k: torch.Tensor, seq_len: int, offset: int = 0):
        """Rotate ``q`` and ``k`` for absolute positions ``offset .. offset + seq_len - 1``.

        Args:
            q, k: tensors whose last two dims are (seq_len, dim), e.g.
                (batch, heads, seq_len, head_dim).
            seq_len: number of positions covered by ``q``/``k``.
            offset: absolute position of the first token, i.e. the number of
                tokens already held in a KV cache (0 without a cache).

        Returns:
            The rotated ``(q, k)`` pair. The cos/sin tables are cast to
            ``q``'s dtype so a rebuilt fp32 cache does not upcast fp16/bf16.
        """
        end = offset + seq_len
        if end > self.cos_cache.shape[2]:
            self._build_cache(end)
        cos = self.cos_cache[:, :, offset:end, :].to(q.dtype)
        sin = self.sin_cache[:, :, offset:end, :].to(q.dtype)
        return apply_rotary(q, cos, sin), apply_rotary(k, cos, sin)


def rotate_half(x: torch.Tensor) -> torch.Tensor:
    """Map the two halves (x1, x2) of the last dim to (-x2, x1)."""
    x1, x2 = x[..., : x.shape[-1] // 2], x[..., x.shape[-1] // 2 :]
    return torch.cat([-x2, x1], dim=-1)


def apply_rotary(x: torch.Tensor, cos: torch.Tensor, sin: torch.Tensor) -> torch.Tensor:
    """Apply the rotary transform ``x * cos + rotate_half(x) * sin``."""
    return (x * cos) + (rotate_half(x) * sin)


class CausalSelfAttention(nn.Module):
    """Multi-head causal self-attention with RoPE and an optional KV cache."""

    def __init__(self, config: "CodeLLMConfig"):
        super().__init__()
        assert config.n_embd % config.n_head == 0
        self.n_head = config.n_head
        self.n_embd = config.n_embd
        self.head_dim = config.n_embd // config.n_head
        self.dropout = config.dropout
        self.c_attn = nn.Linear(config.n_embd, 3 * config.n_embd, bias=False)
        self.c_proj = nn.Linear(config.n_embd, config.n_embd, bias=False)
        self.attn_drop = nn.Dropout(config.dropout)
        self.resid_drop = nn.Dropout(config.dropout)
        self.rotary = RotaryEmbedding(self.head_dim, max_seq_len=config.n_positions)
        # Persistent lower-triangular buffer, kept so existing state_dicts still
        # load. Masks are now built per call by _attn_bias, which also works
        # for sequences longer than n_positions.
        self.register_buffer(
            "bias",
            torch.tril(torch.ones(config.n_positions, config.n_positions))
            .view(1, 1, config.n_positions, config.n_positions),
        )

    @staticmethod
    def _attn_bias(attention_mask, q_len, k_len, dtype, device):
        """Build an additive attention bias broadcastable to (B, H, q_len, k_len).

        The causal part is bottom-right aligned: query ``i`` (absolute position
        ``k_len - q_len + i``) may attend keys ``0 .. k_len - q_len + i``.
        ``attention_mask`` may be additive (floating point, added as-is) or
        boolean (True = may attend, the ``scaled_dot_product_attention``
        convention); it must broadcast against (q_len, k_len).
        """
        allowed = torch.ones(q_len, k_len, device=device).tril(diagonal=k_len - q_len).bool()
        bias = torch.zeros(q_len, k_len, dtype=dtype, device=device).masked_fill(~allowed, float("-inf"))
        if attention_mask is not None:
            if attention_mask.dtype == torch.bool:
                attention_mask = torch.zeros(
                    attention_mask.shape, dtype=dtype, device=attention_mask.device
                ).masked_fill(~attention_mask, float("-inf"))
            bias = bias + attention_mask.to(dtype)
        return bias

    def forward(self, x, attention_mask=None, past_key_value=None, use_cache=False):
        """Attend over cached plus current tokens.

        Args:
            x: (batch, seq_len, n_embd) hidden states.
            attention_mask: optional mask broadcastable to
                (batch, n_head, seq_len, past_len + seq_len), additive float or
                boolean (True = attend); it is combined with the causal mask.
            past_key_value: optional ``(k, v)`` from a previous call, each
                (batch, n_head, past_len, head_dim).
            use_cache: when True, return the concatenated ``(k, v)``.

        Returns:
            ``(y, present)``: ``y`` is (batch, seq_len, n_embd); ``present`` is
            the updated ``(k, v)`` or None when ``use_cache`` is False.
        """
        B, T, C = x.size()
        q, k, v = self.c_attn(x).split(self.n_embd, dim=2)
        q = q.view(B, T, self.n_head, self.head_dim).transpose(1, 2)
        k = k.view(B, T, self.n_head, self.head_dim).transpose(1, 2)
        v = v.view(B, T, self.n_head, self.head_dim).transpose(1, 2)

        # New tokens come after the cached ones: rotate at absolute positions.
        past_len = 0 if past_key_value is None else past_key_value[0].size(2)
        q, k = self.rotary(q, k, seq_len=T, offset=past_len)

        if past_key_value is not None:
            k = torch.cat([past_key_value[0], k], dim=2)
            v = torch.cat([past_key_value[1], v], dim=2)
        present = (k, v) if use_cache else None
        kT = k.size(2)

        if hasattr(F, "scaled_dot_product_attention"):
            if attention_mask is None and (past_len == 0 or T == 1):
                # Trivial causal structure: plain causal without a cache, or a
                # single new token that may see every cached key.
                attn_bias, is_causal = None, past_len == 0
            else:
                # SDPA rejects attn_mask together with is_causal=True, and its
                # is_causal mask is top-left aligned, which is wrong when
                # q_len != k_len; pass one explicit combined mask instead.
                attn_bias, is_causal = self._attn_bias(attention_mask, T, kT, q.dtype, q.device), False
            y = F.scaled_dot_product_attention(
                q, k, v,
                attn_mask=attn_bias,
                dropout_p=self.dropout if self.training else 0.0,
                is_causal=is_causal,
            )
        else:
            scale = 1.0 / math.sqrt(self.head_dim)
            attn = (q @ k.transpose(-2, -1)) * scale
            attn = attn + self._attn_bias(attention_mask, T, kT, attn.dtype, attn.device)
            attn = F.softmax(attn, dim=-1)
            attn = self.attn_drop(attn)
            y = attn @ v

        y = y.transpose(1, 2).contiguous().view(B, T, C)
        y = self.resid_drop(self.c_proj(y))
        return y, present
class SwiGLUFFN(nn.Module):
    """Gated feed-forward layer computing ``dropout(w3(silu(w1 x) * w2 x))``."""

    def __init__(self, config: CodeLLMConfig):
        super().__init__()
        d_model, d_hidden = config.n_embd, config.n_inner
        self.w1 = nn.Linear(d_model, d_hidden, bias=False)  # gate branch
        self.w2 = nn.Linear(d_model, d_hidden, bias=False)  # value branch
        self.w3 = nn.Linear(d_hidden, d_model, bias=False)  # back to model width
        self.drop = nn.Dropout(config.dropout)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        gate = F.silu(self.w1(x))
        value = self.w2(x)
        projected = self.w3(gate * value)
        return self.drop(projected)
class TransformerBlock(nn.Module):
    """Pre-norm decoder layer: ``x + attn(ln_1(x))`` followed by ``h + ffn(ln_2(h))``."""

    def __init__(self, config: CodeLLMConfig):
        super().__init__()
        width, eps = config.n_embd, config.layer_norm_epsilon
        self.ln_1 = nn.LayerNorm(width, eps=eps)
        self.attn = CausalSelfAttention(config)
        self.ln_2 = nn.LayerNorm(width, eps=eps)
        self.ffn = SwiGLUFFN(config)

    def forward(self, x, attention_mask=None, past_key_value=None, use_cache=False):
        """Return ``(hidden_states, present)``; ``present`` comes straight from the attention layer."""
        normed = self.ln_1(x)
        attn_out, present = self.attn(
            normed,
            attention_mask=attention_mask,
            past_key_value=past_key_value,
            use_cache=use_cache,
        )
        hidden = x + attn_out
        hidden = hidden + self.ffn(self.ln_2(hidden))
        return hidden, present
class CodeLLM(nn.Module):
    """Decoder-only language model.

    Token embedding -> dropout -> ``n_layer`` TransformerBlocks -> final
    LayerNorm -> LM head (weight shared with the embedding when
    ``config.tie_word_embeddings`` is True).
    """

    def __init__(self, config: "CodeLLMConfig"):
        super().__init__()
        self.config = config
        self.transformer = nn.ModuleDict(dict(
            wte=nn.Embedding(config.vocab_size, config.n_embd),
            drop=nn.Dropout(config.dropout),
            h=nn.ModuleList([TransformerBlock(config) for _ in range(config.n_layer)]),
            ln_f=nn.LayerNorm(config.n_embd, eps=config.layer_norm_epsilon),
        ))
        self.lm_head = nn.Linear(config.n_embd, config.vocab_size, bias=False)
        if config.tie_word_embeddings:
            self.lm_head.weight = self.transformer.wte.weight
        self.apply(self._init_weights)
        # Attention output projections get a smaller std, scaled by 1/sqrt(2 * n_layer).
        for name, p in self.named_parameters():
            if name.endswith("c_proj.weight"):
                nn.init.normal_(p, mean=0.0, std=config.initializer_range / math.sqrt(2 * config.n_layer))
        print(f"CodeLLM initialized | params: {self.num_parameters:,}")

    def _init_weights(self, module):
        """Normal(0, initializer_range) for Linear/Embedding weights; zero Linear biases."""
        if isinstance(module, nn.Linear):
            nn.init.normal_(module.weight, mean=0.0, std=self.config.initializer_range)
            if module.bias is not None:
                nn.init.zeros_(module.bias)
        elif isinstance(module, nn.Embedding):
            nn.init.normal_(module.weight, mean=0.0, std=self.config.initializer_range)

    @property
    def num_parameters(self):
        """Number of trainable parameters (tied weights counted once)."""
        return sum(p.numel() for p in self.parameters() if p.requires_grad)

    @staticmethod
    def _expand_attention_mask(attention_mask, dtype):
        """Turn a 2-D (batch, key_len) padding mask into an additive 4-D mask.

        Nonzero / True entries mark tokens that may be attended. Returns a
        (batch, 1, 1, key_len) tensor holding 0 for kept keys and the most
        negative finite value of ``dtype`` for padded ones. Masks of any other
        rank, and ``None``, are returned unchanged.
        """
        if attention_mask is None or attention_mask.dim() != 2:
            return attention_mask
        keep = attention_mask[:, None, None, :].bool()
        return torch.zeros(keep.shape, dtype=dtype, device=keep.device).masked_fill(
            ~keep, torch.finfo(dtype).min
        )

    @staticmethod
    def _filter_logits(logits, top_k, top_p):
        """Return a copy of (batch, vocab) ``logits`` with tokens outside top-k / nucleus set to -inf.

        ``top_k`` is disabled when None or <= 0 and ``top_p`` when None or
        >= 1.0. The nucleus always keeps the most likely token.
        """
        if top_k is not None and top_k > 0:
            kth = torch.topk(logits, min(top_k, logits.size(-1))).values[:, [-1]]
            logits = logits.masked_fill(logits < kth, float("-inf"))
        if top_p is not None and top_p < 1.0:
            sorted_logits, sorted_idx = torch.sort(logits, descending=True)
            probs = F.softmax(sorted_logits, dim=-1)
            # Probability mass strictly before each token; drop tokens past top_p.
            remove = probs.cumsum(dim=-1) - probs > top_p
            sorted_logits = sorted_logits.masked_fill(remove, float("-inf"))
            logits = logits.scatter(1, sorted_idx, sorted_logits)
        return logits

    def forward(self, input_ids, attention_mask=None, labels=None, past_key_values=None, use_cache=False):
        """Compute logits (and optionally the LM loss and updated KV caches).

        Args:
            input_ids: (batch, seq_len) token ids.
            attention_mask: optional. A 2-D (batch, past_len + seq_len)
                padding mask (nonzero = real token) is expanded to an additive
                (batch, 1, 1, past_len + seq_len) mask; any other mask is
                passed unchanged to every attention layer.
            labels: optional (batch, seq_len) targets. They are shifted by one
                position internally; entries equal to -100 are ignored.
            past_key_values: optional per-layer ``(k, v)`` caches returned by a
                previous call with ``use_cache=True``.
            use_cache: also return the updated per-layer caches.

        Returns:
            dict with ``"loss"`` (None without labels), ``"logits"``
            (batch, seq_len, vocab_size) and ``"past_key_values"`` (list of
            per-layer caches, or None when ``use_cache`` is False).
        """
        x = self.transformer.wte(input_ids)
        x = self.transformer.drop(x)
        attention_mask = self._expand_attention_mask(attention_mask, x.dtype)
        presents = []
        for i, block in enumerate(self.transformer.h):
            past_kv = past_key_values[i] if past_key_values else None
            x, present = block(x, attention_mask=attention_mask, past_key_value=past_kv, use_cache=use_cache)
            if use_cache:
                presents.append(present)
        x = self.transformer.ln_f(x)
        logits = self.lm_head(x)
        loss = None
        if labels is not None:
            shift_logits = logits[..., :-1, :].contiguous()
            shift_labels = labels[..., 1:].contiguous()
            loss = F.cross_entropy(
                shift_logits.view(-1, shift_logits.size(-1)),
                shift_labels.view(-1),
                ignore_index=-100,
            )
        return {"loss": loss, "logits": logits, "past_key_values": presents if use_cache else None}

    @torch.no_grad()
    def generate(self, input_ids, max_new_tokens=256, temperature=0.8, top_k=50, top_p=0.95, eos_token_id=None):
        """Autoregressively extend ``input_ids`` using the KV cache.

        Args:
            input_ids: (batch, prompt_len) prompt token ids.
            max_new_tokens: maximum number of tokens to append.
            temperature: softmax temperature; ``<= 0`` selects greedy (argmax)
                decoding instead of sampling.
            top_k: keep only the k most likely tokens (disabled if None or <= 0).
            top_p: nucleus-sampling threshold (disabled if None or >= 1.0).
            eos_token_id: stop token; defaults to ``config.eos_token_id``
                (an explicit 0 is honoured).

        Returns:
            (batch, prompt_len + n) tensor with ``n <= max_new_tokens``. Once a
            row emits EOS, its later positions are filled with
            ``config.pad_token_id`` (EOS if that is None). Generation stops when
            every row has finished. The module's train/eval mode is restored on
            return.
        """
        was_training = self.training
        self.eval()
        try:
            eos = self.config.eos_token_id if eos_token_id is None else eos_token_id
            pad = self.config.pad_token_id if self.config.pad_token_id is not None else eos
            finished = torch.zeros(input_ids.size(0), dtype=torch.bool, device=input_ids.device)
            past_key_values = None
            for _ in range(max_new_tokens):
                # After the first step only the newest token is fed; the cache holds the rest.
                input_slice = input_ids if past_key_values is None else input_ids[:, -1:]
                out = self.forward(input_slice, past_key_values=past_key_values, use_cache=True)
                past_key_values = out["past_key_values"]
                logits = out["logits"][:, -1, :]
                if temperature <= 0:
                    next_tok = logits.argmax(dim=-1, keepdim=True)
                else:
                    logits = self._filter_logits(logits / temperature, top_k, top_p)
                    probs = F.softmax(logits, dim=-1)
                    next_tok = torch.multinomial(probs, num_samples=1)
                if eos is not None:
                    # Rows that already produced EOS only emit padding from now on.
                    next_tok = next_tok.masked_fill(finished[:, None], pad)
                    finished |= next_tok[:, 0] == eos
                input_ids = torch.cat([input_ids, next_tok], dim=1)
                if eos is not None and bool(finished.all()):
                    break
            return input_ids
        finally:
            self.train(was_training)