| """ |
| Hugging Face model class for MINDI 1.0 420M. |
| """ |
|
|
| from __future__ import annotations |
|
|
| from dataclasses import dataclass |
| from typing import Optional, Tuple |
|
|
import torch
import torch.nn as nn
import torch.nn.functional as F
from transformers import GenerationMixin, PreTrainedModel
from transformers.modeling_outputs import CausalLMOutputWithPast
|
|
| from .configuration_mindi import MindiConfig |
|
|
|
|
@dataclass
class _Cfg:
    """Plain-Python snapshot of the hyper-parameters the model layers are built from.

    It is populated field-by-field from ``MindiConfig`` in ``MindiForCausalLM.__init__``.
    """

    vocab_size: int
    max_seq_len: int
    d_model: int
    n_layers: int
    n_heads: int
    d_ff: int
    dropout: float
    tie_embeddings: bool
    init_std: float
    rms_norm_eps: float

    @property
    def head_dim(self) -> int:
        """Width of a single attention head, ``d_model // n_heads``.

        Raises:
            ValueError: if ``d_model`` is not an exact multiple of ``n_heads``.
        """
        per_head, leftover = divmod(self.d_model, self.n_heads)
        if leftover:
            raise ValueError("d_model must be divisible by n_heads")
        return per_head
|
|
|
|
class RMSNorm(nn.Module):
    """Root-mean-square normalisation over the last dimension with a learned scale.

    Unlike LayerNorm there is no mean-centering and no bias: the input is divided by
    ``sqrt(mean(x ** 2) + eps)`` and multiplied element-wise by ``weight``.
    """

    def __init__(self, dim: int, eps: float = 1e-5) -> None:
        """
        Args:
            dim: size of the normalised (last) dimension; also the length of ``weight``.
            eps: constant added to the mean square before the reciprocal square root.
        """
        super().__init__()
        self.eps = eps
        self.weight = nn.Parameter(torch.ones(dim))

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Normalise ``x`` over its last dimension and apply the learned scale.

        The statistics are computed in at least float32. In float16, ``x ** 2``
        overflows to ``inf`` once ``|x|`` exceeds about 256, which would otherwise make
        the output collapse to zero. Higher-precision inputs (float64) keep their
        precision. The normalised tensor is cast back to the input dtype before
        scaling, so the result's dtype follows the usual promotion of ``weight`` with
        ``x``.
        """
        input_dtype = x.dtype
        compute_dtype = torch.promote_types(input_dtype, torch.float32)
        x_wide = x.to(compute_dtype)
        mean_square = x_wide.pow(2).mean(dim=-1, keepdim=True)
        normed = (x_wide * torch.rsqrt(mean_square + self.eps)).to(input_dtype)
        return self.weight * normed
|
|
|
|
class RotaryEmbedding(nn.Module):
    """Interleaved rotary position embedding (RoPE) with precomputed cos/sin tables.

    Each adjacent channel pair ``(2i, 2i + 1)`` of the last dimension is rotated by the
    angle ``position * 10000 ** (-2i / head_dim)``.
    """

    def __init__(self, head_dim: int, max_seq_len: int) -> None:
        """
        Args:
            head_dim: per-head width; must be even so channels can be paired.
            max_seq_len: number of positions to precompute tables for.

        Raises:
            ValueError: if ``head_dim`` is odd.
        """
        super().__init__()
        if head_dim % 2 != 0:
            raise ValueError("head_dim must be even for rotary embeddings")
        inv_freq = 1.0 / (10000 ** (torch.arange(0, head_dim, 2).float() / head_dim))
        t = torch.arange(max_seq_len, dtype=torch.float32)
        freqs = torch.outer(t, inv_freq)  # (max_seq_len, head_dim // 2)
        # Non-persistent: the tables are derived from hyper-parameters, not checkpointed.
        self.register_buffer("cos_cached", torch.cos(freqs), persistent=False)
        self.register_buffer("sin_cached", torch.sin(freqs), persistent=False)

    def forward(self, q: torch.Tensor, k: torch.Tensor, seq_len: int) -> Tuple[torch.Tensor, torch.Tensor]:
        """Rotate queries and keys for positions ``0 .. seq_len - 1``.

        The tables are broadcast as ``(1, 1, seq_len, head_dim // 2)``. This assumes q/k
        are laid out as ``(batch, heads, seq, head_dim)``, as ``CausalSelfAttention``
        produces them.

        Raises:
            ValueError: if ``seq_len`` exceeds the number of precomputed positions.
                Without this check the slice below silently comes up short and the
                failure surfaces as an opaque broadcasting error.
        """
        cached_len = self.cos_cached.size(0)
        if seq_len > cached_len:
            raise ValueError(
                f"sequence length {seq_len} exceeds the rotary cache size "
                f"(max_seq_len={cached_len})"
            )
        cos = self.cos_cached[:seq_len].unsqueeze(0).unsqueeze(0)
        sin = self.sin_cached[:seq_len].unsqueeze(0).unsqueeze(0)
        return self._apply_rotary(q, cos, sin), self._apply_rotary(k, cos, sin)

    @staticmethod
    def _apply_rotary(x: torch.Tensor, cos: torch.Tensor, sin: torch.Tensor) -> torch.Tensor:
        """Rotate even/odd channel pairs of ``x`` and re-interleave them.

        ``cos``/``sin`` are cast to ``x.dtype``. Otherwise the float32 tables would
        promote half-precision activations to float32 and leave q/k with a different
        dtype than v in attention.
        """
        cos = cos.to(dtype=x.dtype)
        sin = sin.to(dtype=x.dtype)
        x1 = x[..., ::2]
        x2 = x[..., 1::2]
        xe = x1 * cos - x2 * sin
        xo = x1 * sin + x2 * cos
        return torch.stack((xe, xo), dim=-1).flatten(-2)
|
|
|
|
class CausalSelfAttention(nn.Module):
    """Multi-head causal self-attention with rotary embeddings applied to queries/keys.

    All four projections are bias-free ``d_model -> d_model`` linears. ``self.dropout``
    is never called as a layer; only its probability is handed to
    ``F.scaled_dot_product_attention`` as attention-weight dropout while training.
    """

    def __init__(self, cfg: _Cfg) -> None:
        super().__init__()
        self.n_heads = cfg.n_heads
        self.head_dim = cfg.head_dim
        self.scale = pow(self.head_dim, -0.5)
        width = cfg.d_model
        self.q_proj = nn.Linear(width, width, bias=False)
        self.k_proj = nn.Linear(width, width, bias=False)
        self.v_proj = nn.Linear(width, width, bias=False)
        self.o_proj = nn.Linear(width, width, bias=False)
        self.dropout = nn.Dropout(cfg.dropout)
        self.rotary = RotaryEmbedding(self.head_dim, cfg.max_seq_len)

    def _split_heads(self, projected: torch.Tensor) -> torch.Tensor:
        """Reshape ``(batch, seq, d_model)`` into ``(batch, n_heads, seq, head_dim)``."""
        batch, length, _ = projected.shape
        return projected.view(batch, length, self.n_heads, self.head_dim).transpose(1, 2)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Attend each position to itself and all earlier positions of ``x``.

        ``x`` must be 3-D ``(batch, seq, d_model)``; the output has the same shape.
        """
        batch, length, _ = x.shape
        query, key, value = [
            self._split_heads(proj(x)) for proj in (self.q_proj, self.k_proj, self.v_proj)
        ]
        query, key = self.rotary(query, key, seq_len=length)
        attn_dropout = self.dropout.p if self.training else 0.0
        context = F.scaled_dot_product_attention(
            query,
            key,
            value,
            attn_mask=None,
            dropout_p=attn_dropout,
            is_causal=True,
            scale=self.scale,
        )
        # (batch, n_heads, seq, head_dim) -> (batch, seq, n_heads * head_dim)
        merged = context.transpose(1, 2).reshape(batch, length, -1)
        return self.o_proj(merged)
|
|
|
|
class FeedForward(nn.Module):
    """Position-wise MLP: Linear(d_model->d_ff), tanh-GELU, Linear(d_ff->d_model), dropout.

    Both projections are bias-free.
    """

    def __init__(self, cfg: _Cfg) -> None:
        super().__init__()
        self.fc1 = nn.Linear(cfg.d_model, cfg.d_ff, bias=False)
        self.fc2 = nn.Linear(cfg.d_ff, cfg.d_model, bias=False)
        self.dropout = nn.Dropout(cfg.dropout)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Apply the MLP independently at every position of ``x``."""
        hidden = F.gelu(self.fc1(x), approximate="tanh")
        return self.dropout(self.fc2(hidden))
|
|
|
|
class TransformerBlock(nn.Module):
    """Pre-norm decoder layer built from two residual sub-layers.

    The layer computes ``h = x + attn(norm1(x))`` followed by ``h + ffn(norm2(h))``.
    """

    def __init__(self, cfg: _Cfg) -> None:
        super().__init__()
        eps = cfg.rms_norm_eps
        self.norm1 = RMSNorm(cfg.d_model, eps)
        self.attn = CausalSelfAttention(cfg)
        self.norm2 = RMSNorm(cfg.d_model, eps)
        self.ffn = FeedForward(cfg)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Run attention, then the MLP, each wrapped in a residual connection."""
        after_attn = x + self.attn(self.norm1(x))
        return after_attn + self.ffn(self.norm2(after_attn))
|
|
|
|
class MindiForCausalLM(PreTrainedModel, GenerationMixin):
    """MINDI decoder-only causal language model wrapped as a Hugging Face model.

    Pipeline: token embedding -> dropout -> ``n_layers`` pre-norm ``TransformerBlock``s
    -> final ``RMSNorm`` -> bias-free linear LM head. The head is optionally tied to
    the embedding matrix.

    ``GenerationMixin`` is listed explicitly. Recent ``transformers`` releases no
    longer mix it into ``PreTrainedModel``, and without it ``generate()`` is
    unavailable. Listing it is harmless on older releases that still inherit it.
    """

    config_class = MindiConfig
    base_model_prefix = "mindi"
    supports_gradient_checkpointing = False

    def __init__(self, config: MindiConfig):
        super().__init__(config)
        cfg = _Cfg(
            vocab_size=config.vocab_size,
            max_seq_len=config.max_seq_len,
            d_model=config.d_model,
            n_layers=config.n_layers,
            n_heads=config.n_heads,
            d_ff=config.d_ff,
            dropout=config.dropout,
            tie_embeddings=config.tie_embeddings,
            init_std=config.init_std,
            rms_norm_eps=config.rms_norm_eps,
        )

        self.embed_tokens = nn.Embedding(cfg.vocab_size, cfg.d_model)
        self.dropout = nn.Dropout(cfg.dropout)
        self.blocks = nn.ModuleList([TransformerBlock(cfg) for _ in range(cfg.n_layers)])
        self.norm_final = RMSNorm(cfg.d_model, cfg.rms_norm_eps)
        self.lm_head = nn.Linear(cfg.d_model, cfg.vocab_size, bias=False)

        if cfg.tie_embeddings:
            # Share one Parameter between the input embedding and the output projection.
            # NOTE(review): PreTrainedModel.post_init() may also tie these according to
            # config.tie_word_embeddings; confirm MindiConfig keeps that flag
            # consistent with tie_embeddings.
            self.lm_head.weight = self.embed_tokens.weight

        self.post_init()

    def _init_weights(self, module: nn.Module) -> None:
        """Initialise a single submodule; this is the per-module hook ``post_init`` uses.

        Linear and embedding weights are drawn from ``N(0, init_std ** 2)``. RMSNorm
        scales are reset to one, their construction-time value. Without that reset,
        the scales would be left untouched when this hook is the only initialiser,
        for example for weights missing from a checkpoint.
        """
        if isinstance(module, nn.Linear):
            nn.init.normal_(module.weight, mean=0.0, std=self.config.init_std)
        elif isinstance(module, nn.Embedding):
            nn.init.normal_(module.weight, mean=0.0, std=self.config.init_std)
        elif isinstance(module, RMSNorm):
            nn.init.ones_(module.weight)

    def get_input_embeddings(self) -> nn.Module:
        return self.embed_tokens

    def set_input_embeddings(self, value: nn.Module) -> None:
        self.embed_tokens = value

    def get_output_embeddings(self) -> nn.Module:
        return self.lm_head

    def set_output_embeddings(self, new_embeddings: nn.Module) -> None:
        self.lm_head = new_embeddings

    def forward(
        self,
        input_ids: torch.Tensor,
        attention_mask: Optional[torch.Tensor] = None,
        labels: Optional[torch.Tensor] = None,
        **kwargs,
    ) -> CausalLMOutputWithPast:
        """Compute next-token logits and, optionally, the language-modelling loss.

        Args:
            input_ids: 2-D ``(batch, seq)`` token ids.
            attention_mask: accepted for API compatibility but IGNORED. Every position
                attends causally to all earlier positions, so padding tokens in a
                batch are not masked out.
            labels: optional target ids aligned with ``input_ids``. The shift by one
                happens here: logits at position ``t`` are scored against
                ``labels[:, t + 1]``. Positions equal to ``-100`` are ignored.
            **kwargs: extra arguments (for example from ``generate``) are discarded.

        Returns:
            ``CausalLMOutputWithPast`` with ``logits`` of shape
            ``(batch, seq, vocab_size)`` and ``loss`` (``None`` when ``labels`` is not
            given). No key/value cache is produced, so ``past_key_values`` is ``None``.
        """
        del attention_mask, kwargs

        hidden = self.dropout(self.embed_tokens(input_ids))
        for block in self.blocks:
            hidden = block(hidden)
        logits = self.lm_head(self.norm_final(hidden))

        loss = None
        if labels is not None:
            # Upcast to float32 so log-softmax over the vocabulary keeps full precision
            # when the model runs in half precision. Labels follow the logits' device
            # in case the two were placed differently.
            shift_logits = logits[:, :-1, :].float()
            shift_labels = labels[:, 1:].to(logits.device)
            loss = F.cross_entropy(
                shift_logits.reshape(-1, shift_logits.size(-1)),
                shift_labels.reshape(-1),
                ignore_index=-100,
            )

        return CausalLMOutputWithPast(loss=loss, logits=logits)

    @torch.no_grad()
    def prepare_inputs_for_generation(self, input_ids: torch.Tensor, **kwargs):
        """Feed only ``input_ids`` back into ``forward`` at each generation step.

        There is no key/value cache, so the whole prefix is re-encoded every step.
        All other generation kwargs, including ``attention_mask``, are dropped.
        """
        del kwargs
        return {"input_ids": input_ids}
|
|