import math
import os
from dataclasses import dataclass
from pathlib import Path

import torch
import torch.nn as nn
import torch.nn.functional as F
from transformers import PreTrainedModel

from .configuration_nanogpt import NanoGPTConfig


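# Parameter-free RMSNorm (no learnable scale), applied over the last dimension.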
def _rms_norm(x: torch.Tensor) -> torch.Tensor:
    return F.rms_norm(x, (x.size(-1),))


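# Rotary position embeddings (RoPE): rotate channel pairs (x1, x2) by the
# per-position angles held in the precomputed cos/sin tables.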
def _apply_rotary_emb(x: torch.Tensor, cos: torch.Tensor, sin: torch.Tensor) -> torch.Tensor:
    assert x.ndim == 4  # (B, T, n_heads, head_dim)
    d = x.shape[3] // 2
    x1, x2 = x[..., :d], x[..., d:]
    y1 = x1 * cos + x2 * sin
    y2 = x1 * (-sin) + x2 * cos
    out = torch.cat([y1, y2], 3)
    return out.to(x.dtype)


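# Grouped-query attention helper: repeat each key/value head n_rep times so the
# KV tensors match the number of query heads.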
def _repeat_kv(x: torch.Tensor, n_rep: int) -> torch.Tensor:
    if n_rep == 1:
        return x
    bs, n_kv_heads, slen, head_dim = x.shape
    return (
        x[:, :, None, :, :]
        .expand(bs, n_kv_heads, n_rep, slen, head_dim)
        .reshape(bs, n_kv_heads * n_rep, slen, head_dim)
    )


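# Causal self-attention with rotary embeddings, QK RMSNorm, and grouped-query
# KV heads (n_kv_head may be smaller than n_head).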
class CausalSelfAttention(nn.Module):
    def __init__(self, config: NanoGPTConfig, layer_idx: int):
        super().__init__()
        self.layer_idx = layer_idx
        self.n_head = config.n_head
        self.n_kv_head = config.n_kv_head
        self.n_embd = config.n_embd
        self.head_dim = self.n_embd // self.n_head
        assert self.n_embd % self.n_head == 0
        assert self.n_kv_head <= self.n_head and self.n_head % self.n_kv_head == 0
        self.c_q = nn.Linear(self.n_embd, self.n_head * self.head_dim, bias=False)
        self.c_k = nn.Linear(self.n_embd, self.n_kv_head * self.head_dim, bias=False)
        self.c_v = nn.Linear(self.n_embd, self.n_kv_head * self.head_dim, bias=False)
        self.c_proj = nn.Linear(self.n_embd, self.n_embd, bias=False)

    def forward(self, x: torch.Tensor, cos_sin, kv_cache=None) -> torch.Tensor:
        B, T, C = x.size()
        q = self.c_q(x).view(B, T, self.n_head, self.head_dim)
        k = self.c_k(x).view(B, T, self.n_kv_head, self.head_dim)
        v = self.c_v(x).view(B, T, self.n_kv_head, self.head_dim)
        cos, sin = cos_sin
        q, k = _apply_rotary_emb(q, cos, sin), _apply_rotary_emb(k, cos, sin)
        q, k = _rms_norm(q), _rms_norm(k)  # QK norm
        q, k, v = q.transpose(1, 2), k.transpose(1, 2), v.transpose(1, 2)
        Tq = q.size(2)
        Tk = k.size(2)
        nrep = self.n_head // self.n_kv_head
        k, v = _repeat_kv(k, nrep), _repeat_kv(v, nrep)
        if Tq == Tk:
            # Full-sequence forward: plain causal attention.
            y = F.scaled_dot_product_attention(q, k, v, is_causal=True)
        elif Tq == 1:
            # Single new token: it may attend to every key/value position.
            y = F.scaled_dot_product_attention(q, k, v, is_causal=False)
        else:
            # General case: attend to the whole prefix, causal within the new chunk.
            attn_mask = torch.zeros((Tq, Tk), dtype=torch.bool, device=q.device)
            prefix_len = Tk - Tq
            if prefix_len > 0:
                attn_mask[:, :prefix_len] = True
            attn_mask[:, prefix_len:] = torch.tril(torch.ones((Tq, Tq), dtype=torch.bool, device=q.device))
            y = F.scaled_dot_product_attention(q, k, v, attn_mask=attn_mask)
        y = y.transpose(1, 2).contiguous().view(B, T, -1)
        y = self.c_proj(y)
        return y


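# Position-wise feed-forward network with a 4x hidden expansion and the
# squared-ReLU activation relu(x)**2.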
class MLP(nn.Module):
    def __init__(self, config: NanoGPTConfig):
        super().__init__()
        self.c_fc = nn.Linear(config.n_embd, 4 * config.n_embd, bias=False)
        self.c_proj = nn.Linear(4 * config.n_embd, config.n_embd, bias=False)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        x = self.c_fc(x)
        x = F.relu(x).square()
        x = self.c_proj(x)
        return x


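# A single transformer block: pre-norm attention and pre-norm MLP, each wrapped
# in a residual connection.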
class Block(nn.Module):
    def __init__(self, config: NanoGPTConfig, layer_idx: int):
        super().__init__()
        self.attn = CausalSelfAttention(config, layer_idx)
        self.mlp = MLP(config)

    def forward(self, x: torch.Tensor, cos_sin, kv_cache=None) -> torch.Tensor:
        x = x + self.attn(_rms_norm(x), cos_sin, kv_cache)
        x = x + self.mlp(_rms_norm(x))
        return x


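# Hugging Face-compatible wrapper: token embedding, a stack of Blocks, and an
# untied lm_head, plus non-persistent rotary cos/sin buffers.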
class NanoGPTModel(PreTrainedModel):
    config_class = NanoGPTConfig

    def __init__(self, config: NanoGPTConfig):
        super().__init__(config)
        self.transformer = nn.ModuleDict({
            "wte": nn.Embedding(config.vocab_size, config.n_embd),
            "h": nn.ModuleList([Block(config, layer_idx) for layer_idx in range(config.n_layer)]),
        })
        self.lm_head = nn.Linear(config.n_embd, config.vocab_size, bias=False)
        # Rotary tables are over-allocated relative to the training sequence length.
        self.rotary_seq_len = config.sequence_len * 10
        head_dim = config.n_embd // config.n_head
        cos, sin = self._precompute_rotary_embeddings(self.rotary_seq_len, head_dim)
        self.register_buffer("cos", cos, persistent=False)
        self.register_buffer("sin", sin, persistent=False)
        # Token embedding weights are kept in bfloat16.
        self.transformer.wte.to(dtype=torch.bfloat16)
        self.post_init()

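    # Linear weights are drawn from N(0, std) with std = 1/sqrt(fan_in), scaled
    # down by sqrt(fan_out/fan_in) when fan_out < fan_in; embeddings use
    # unit-variance normals.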
    def _init_weights(self, module: nn.Module):
        if isinstance(module, nn.Linear):
            fan_out = module.weight.size(0)
            fan_in = module.weight.size(1)
            std = 1.0 / math.sqrt(fan_in) * min(1.0, math.sqrt(fan_out / fan_in))
            torch.nn.init.normal_(module.weight, mean=0.0, std=std)
            if module.bias is not None:
                torch.nn.init.zeros_(module.bias)
        elif isinstance(module, nn.Embedding):
            torch.nn.init.normal_(module.weight, mean=0.0, std=1.0)

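    # Precompute the RoPE cos/sin tables for every position up to seq_len,
    # stored in bfloat16 with broadcastable shape (1, seq_len, 1, head_dim // 2).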
    def _precompute_rotary_embeddings(self, seq_len: int, head_dim: int, base: int = 10000, device=None):
        if device is None:
            device = self.transformer.wte.weight.device
        # Tensors cannot be materialized on the meta device; fall back to CPU.
        if device.type == 'meta':
            device = torch.device('cpu')
        channel_range = torch.arange(0, head_dim, 2, dtype=torch.float32, device=device)
        inv_freq = 1.0 / (base ** (channel_range / head_dim))
        t = torch.arange(seq_len, dtype=torch.float32, device=device)
        freqs = torch.outer(t, inv_freq)
        cos, sin = freqs.cos(), freqs.sin()
        cos, sin = cos.bfloat16(), sin.bfloat16()
        cos, sin = cos[None, :, None, :], sin[None, :, None, :]
        return cos, sin

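    # Forward pass over a full sequence (no KV cache). Returns soft-capped logits
    # and, when labels are provided, the mean cross-entropy loss.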
    def forward(self, input_ids: torch.Tensor, labels=None, **kwargs):
        idx = input_ids
        B, T = idx.size()
        T0 = 0  # without a KV cache, positions always start at 0
        cos_sin = self.cos[:, T0:T0+T], self.sin[:, T0:T0+T]
        x = self.transformer.wte(idx)
        x = x.float()
        x = _rms_norm(x)
        for block in self.transformer.h:
            x = block(x, cos_sin, None)
        x = _rms_norm(x)
        # Soft-cap the logits so they stay within (-softcap, softcap).
        softcap = 15
        logits = self.lm_head(x)
        logits = softcap * torch.tanh(logits / softcap)
        loss = None
        if labels is not None:
            loss = F.cross_entropy(logits.view(-1, logits.size(-1)), labels.view(-1), ignore_index=-1, reduction='mean')
        return {"loss": loss, "logits": logits}

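    # Checkpoint loading: prefer a raw torch.save checkpoint found next to the
    # config (pytorch_model.bin, model.bin, model_*.pt, *.bin); otherwise defer
    # to the stock PreTrainedModel.from_pretrained.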
    @classmethod
    def from_pretrained(cls, pretrained_model_name_or_path, *model_args, **kwargs):
        config = kwargs.pop("config", None)
        subfolder = kwargs.pop("subfolder", None)
        device_map = kwargs.get("device_map")
        if device_map is not None:
            # device_map handling is left to the standard HF loading path.
            if subfolder is not None:
                kwargs["subfolder"] = subfolder
            if config is not None:
                kwargs["config"] = config
            return super().from_pretrained(pretrained_model_name_or_path, *model_args, **kwargs)

        base_path = Path(pretrained_model_name_or_path)
        if subfolder:
            base_path = base_path / subfolder

        # Search the local directory for a raw checkpoint file.
        weight_path = None
        if base_path.is_dir():
            candidate_files = [
                base_path / "pytorch_model.bin",
                base_path / "model.bin",
            ]
            candidate_files.extend(sorted(base_path.glob("model_*.pt"), reverse=True))
            candidate_files.extend(sorted(base_path.glob("*.bin"), reverse=True))
            for cand in candidate_files:
                if cand.is_file():
                    weight_path = cand
                    break

        if weight_path is None:
            # No raw checkpoint found; defer to the standard HF loading path.
            if subfolder is not None:
                kwargs["subfolder"] = subfolder
            if config is not None:
                kwargs["config"] = config
            return super().from_pretrained(pretrained_model_name_or_path, *model_args, **kwargs)

        if config is None:
            config = NanoGPTConfig.from_pretrained(pretrained_model_name_or_path, subfolder=subfolder)

        torch_dtype = kwargs.pop("torch_dtype", None)
        strict = kwargs.pop("strict", True)

        state_dict = torch.load(str(weight_path), map_location="cpu")
        if isinstance(state_dict, dict) and "state_dict" in state_dict:
            state_dict = state_dict["state_dict"]
        # Strip the torch.compile wrapper prefix from checkpoint keys.
        # (str.lstrip removes a character set, not a prefix, so use removeprefix.)
        state_dict = {k.removeprefix("_orig_mod."): v for k, v in state_dict.items()}

        model = cls(config, *model_args)
        model.load_state_dict(state_dict, strict=strict)
        if torch_dtype is not None:
            model = model.to(dtype=torch_dtype)
        model.eval()
        return model