"""
NanoMindForFunctionCalling: a HuggingFace PreTrainedModel wrapper.

Usage after downloading from the HF hub:

    from hf_wrapper import NanoMindForFunctionCalling
    model = NanoMindForFunctionCalling.from_pretrained("shawneil/NanoMind-MobileActions")
    # model is ready for inference; no separate architecture file needed
"""

from dataclasses import dataclass

import torch
import torch.nn as nn
import torch.nn.functional as F

try:
    from transformers import PreTrainedModel, PretrainedConfig
    from transformers.modeling_outputs import CausalLMOutput
    HF_AVAILABLE = True
except ImportError:
    HF_AVAILABLE = False


@dataclass
class ModelConfig:
    vocab_size: int = 50257
    d_model: int = 512
    n_heads: int = 8
    n_kv_heads: int = 2
    n_layers: int = 8
    max_seq_len: int = 512
    ff_mult: int = 4
    dropout: float = 0.0
    # MoE fields are accepted for checkpoint compatibility; this wrapper only
    # implements the dense path, so they are unused below.
    use_moe: bool = False
    num_experts: int = 4
    top_k_experts: int = 2


class RMSNorm(nn.Module):
    """Root-mean-square LayerNorm: no mean subtraction, no bias."""
    def __init__(self, dim, eps=1e-6):
        super().__init__()
        self.eps = eps
        self.weight = nn.Parameter(torch.ones(dim))

    def forward(self, x):
        # Normalize in float32 for stability, then cast back to the input dtype.
        x32 = x.float()
        normed = x32 * x32.pow(2).mean(-1, keepdim=True).add(self.eps).rsqrt()
        return normed.to(x.dtype) * self.weight


def _freqs_cis(head_dim, max_len, theta=10000.0):
    """Precompute RoPE rotation factors as complex exponentials."""
    freqs = 1.0 / (theta ** (torch.arange(0, head_dim, 2).float() / head_dim))
    t = torch.arange(max_len, device=freqs.device)
    freqs = torch.outer(t, freqs)
    return torch.polar(torch.ones_like(freqs), freqs)
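# For example, _freqs_cis(64, 1024) returns a complex64 tensor of shape
# (1024, 32): one unit-magnitude rotation per (position, channel pair).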


def _rope(xq, xk, fc):
    """Apply rotary position embeddings to query and key tensors."""
    def rot(x, f):
        # Pair adjacent channels into complex numbers, rotate, then unpack.
        xc = torch.view_as_complex(x.float().reshape(*x.shape[:-1], -1, 2))
        return torch.view_as_real(
            xc * f[:x.shape[1]].unsqueeze(0).unsqueeze(2)
        ).flatten(3).to(x.dtype)
    return rot(xq, fc), rot(xk, fc)


class GQAttn(nn.Module):
    """Grouped-query attention: n_heads query heads share n_kv_heads K/V heads."""
    def __init__(self, cfg):
        super().__init__()
        self.nh = cfg.n_heads
        self.nkv = cfg.n_kv_heads
        self.hd = cfg.d_model // cfg.n_heads
        self.q = nn.Linear(cfg.d_model, cfg.n_heads * self.hd, bias=False)
        self.k = nn.Linear(cfg.d_model, cfg.n_kv_heads * self.hd, bias=False)
        self.v = nn.Linear(cfg.d_model, cfg.n_kv_heads * self.hd, bias=False)
        self.o = nn.Linear(cfg.n_heads * self.hd, cfg.d_model, bias=False)
        self.drop = cfg.dropout

    def forward(self, x, fc):
        B, T, _ = x.shape
        q = self.q(x).view(B, T, self.nh, self.hd)
        k = self.k(x).view(B, T, self.nkv, self.hd)
        v = self.v(x).view(B, T, self.nkv, self.hd)
        q, k = _rope(q, k, fc)
        # Expand each K/V head to cover its group of query heads.
        r = self.nh // self.nkv
        k = k.repeat_interleave(r, 2)
        v = v.repeat_interleave(r, 2)
        q, k, v = q.transpose(1, 2), k.transpose(1, 2), v.transpose(1, 2)
        out = F.scaled_dot_product_attention(
            q, k, v, None,
            self.drop if self.training else 0.0, is_causal=True)
        return self.o(out.transpose(1, 2).contiguous().view(B, T, -1))
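# With the default config (d_model=512, n_heads=8, n_kv_heads=2): head_dim is
# 64, the K/V projections are 2 * 64 = 128-dim, and repeat_interleave expands
# each K/V head to serve 8 // 2 = 4 query heads.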


class SwiGLU(nn.Module):
    """SwiGLU feed-forward: w2(silu(w1(x)) * w3(x))."""
    def __init__(self, cfg):
        super().__init__()
        # Shrink the hidden size by 2/3 (three matrices instead of two), then
        # round up to a multiple of 64 for hardware-friendly shapes.
        h = (int(cfg.d_model * cfg.ff_mult * 2 / 3) + 63) // 64 * 64
        self.w1 = nn.Linear(cfg.d_model, h, bias=False)
        self.w2 = nn.Linear(h, cfg.d_model, bias=False)
        self.w3 = nn.Linear(cfg.d_model, h, bias=False)

    def forward(self, x):
        return self.w2(F.silu(self.w1(x)) * self.w3(x))
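# Worked example with the defaults: int(512 * 4 * 2 / 3) = 1365, and rounding
# up to the next multiple of 64 gives (1365 + 63) // 64 * 64 = 1408.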


class Block(nn.Module):
    """Pre-norm transformer block: attention then SwiGLU, each with a residual."""
    def __init__(self, cfg):
        super().__init__()
        self.an = RMSNorm(cfg.d_model)
        self.fn = RMSNorm(cfg.d_model)
        self.attn = GQAttn(cfg)
        self.ff = SwiGLU(cfg)
        self.drop = nn.Dropout(cfg.dropout)

    def forward(self, x, fc):
        x = x + self.drop(self.attn(self.an(x), fc))
        return x + self.drop(self.ff(self.fn(x)))


class _CoreModel(nn.Module):
    def __init__(self, cfg):
        super().__init__()
        self.cfg = cfg
        self.embed = nn.Embedding(cfg.vocab_size, cfg.d_model)
        self.drop = nn.Dropout(cfg.dropout)
        self.blocks = nn.ModuleList([Block(cfg) for _ in range(cfg.n_layers)])
        self.norm = RMSNorm(cfg.d_model)
        self.lm_head = nn.Linear(cfg.d_model, cfg.vocab_size, bias=False)
        # Tie the input embedding to the output projection.
        self.embed.weight = self.lm_head.weight
        # Precompute RoPE factors with 2x headroom over max_seq_len.
        self.register_buffer("freqs_cis",
            _freqs_cis(cfg.d_model // cfg.n_heads, cfg.max_seq_len * 2))

    def forward(self, idx, targets=None, loss_mask=None):
        B, T = idx.shape
        x = self.drop(self.embed(idx))
        fc = self.freqs_cis[:T]
        for blk in self.blocks:
            x = blk(x, fc)
        logits = self.lm_head(self.norm(x))
        loss = None
        if targets is not None:
            fl = logits.view(-1, logits.size(-1))
            ft = targets.view(-1)
            if loss_mask is not None:
                # Keep only the positions where the mask is nonzero.
                m = loss_mask.view(-1).bool()
                fl = fl[m]
                ft = ft[m]
            loss = F.cross_entropy(fl, ft, ignore_index=-1)
        return logits, loss
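
    # Note: targets are expected to arrive already shifted one position
    # relative to idx; no shifting happens here. A typical loss_mask zeroes
    # out prompt tokens so only the response contributes to the loss, e.g.
    # (names are illustrative):
    #     logits, loss = core(ids, targets=shifted_ids, loss_mask=resp_mask)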

    @torch.no_grad()
    def generate(self, idx, max_new_tokens=200, temperature=0.8, top_k=50):
        for _ in range(max_new_tokens):
            # Crop the context to the model's maximum sequence length.
            ic = idx[:, -self.cfg.max_seq_len:]
            logits, _ = self(ic)
            logits = logits[:, -1, :] / temperature
            # Top-k filtering: drop everything below the k-th largest logit.
            v, _ = torch.topk(logits, min(top_k, logits.size(-1)))
            logits[logits < v[:, [-1]]] = float("-inf")
            idx = torch.cat([idx, torch.multinomial(F.softmax(logits, -1), 1)], 1)
        return idx


if HF_AVAILABLE:
    class NanoMindConfig(PretrainedConfig):
        model_type = "nanomind"

        def __init__(self, vocab_size=50257, d_model=512, n_heads=8,
                     n_kv_heads=2, n_layers=8, max_seq_len=512,
                     ff_mult=4, dropout=0.0, **kwargs):
            super().__init__(**kwargs)
            self.vocab_size = vocab_size
            self.d_model = d_model
            self.n_heads = n_heads
            self.n_kv_heads = n_kv_heads
            self.n_layers = n_layers
            self.max_seq_len = max_seq_len
            self.ff_mult = ff_mult
            self.dropout = dropout
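
    # Note: loading through AutoModel would additionally require registering
    # these classes (e.g. an auto_map entry in config.json together with
    # trust_remote_code=True); NanoMindForFunctionCalling.from_pretrained
    # works without registration.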

    class NanoMindForFunctionCalling(PreTrainedModel):
        config_class = NanoMindConfig

        def __init__(self, config: NanoMindConfig):
            super().__init__(config)
            cfg = ModelConfig(
                vocab_size=config.vocab_size, d_model=config.d_model,
                n_heads=config.n_heads, n_kv_heads=config.n_kv_heads,
                n_layers=config.n_layers, max_seq_len=config.max_seq_len,
                ff_mult=config.ff_mult, dropout=config.dropout,
            )
            self.model = _CoreModel(cfg)
            self.post_init()

        def forward(self, input_ids, labels=None, loss_mask=None, **kwargs):
            logits, loss = self.model(input_ids, labels, loss_mask)
            return CausalLMOutput(loss=loss, logits=logits)

        @torch.no_grad()
        def generate_text(self, idx, max_new_tokens=200, temperature=0.8,
                          top_k=50):
            # Named generate_text to avoid clashing with PreTrainedModel.generate.
            return self.model.generate(idx, max_new_tokens, temperature, top_k)

        @classmethod
        def from_nanomind_checkpoint(cls, ckpt_path: str):
            """Load from a raw NanoMind .pt checkpoint (no HF config needed)."""
            ckpt = torch.load(ckpt_path, map_location="cpu", weights_only=False)
            raw_cfg = ckpt.get("config", {})
            hf_cfg = NanoMindConfig(
                vocab_size=raw_cfg.get("vocab_size", 50257),
                d_model=raw_cfg.get("d_model", 512),
                n_heads=raw_cfg.get("n_heads", 8),
                n_kv_heads=raw_cfg.get("n_kv_heads", 2),
                n_layers=raw_cfg.get("n_layers", 8),
                max_seq_len=raw_cfg.get("max_seq_len", 512),
                ff_mult=raw_cfg.get("ff_mult", 4),
                dropout=raw_cfg.get("dropout", 0.0),
            )
            wrapper = cls(hf_cfg)
            state = ckpt["model"]
            # Raw checkpoints store weights without the "model." wrapper prefix.
            if not any(k.startswith("model.") for k in state):
                state = {"model." + k: v for k, v in state.items()}
            wrapper.load_state_dict(state, strict=True)
            return wrapper
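
        # Conversion sketch (path and output directory are hypothetical):
        #     m = NanoMindForFunctionCalling.from_nanomind_checkpoint("ckpt.pt")
        #     m.save_pretrained("NanoMind-MobileActions-hf")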


else:
    # Fallback when transformers is not installed: a plain nn.Module exposing
    # the same public surface.
    class NanoMindForFunctionCalling(nn.Module):
        def __init__(self, cfg: ModelConfig):
            super().__init__()
            self.model = _CoreModel(cfg)

        @classmethod
        def from_checkpoint(cls, ckpt_path):
            ckpt = torch.load(ckpt_path, map_location="cpu", weights_only=False)
            raw_cfg = ckpt.get("config", {})
            # Drop any checkpoint config keys that ModelConfig does not define.
            cfg = ModelConfig(**{k: v for k, v in raw_cfg.items()
                                 if k in ModelConfig.__dataclass_fields__})
            obj = cls(cfg)
            state = ckpt["model"]
            if not any(k.startswith("model.") for k in state):
                state = {"model." + k: v for k, v in state.items()}
            obj.load_state_dict(state, strict=True)
            return obj

        def forward(self, idx, targets=None, loss_mask=None):
            return self.model(idx, targets, loss_mask)

        def generate_text(self, idx, **kw):
            return self.model.generate(idx, **kw)
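

if __name__ == "__main__":
    # Minimal smoke test: a sketch, not part of the published API. It builds a
    # deliberately tiny config and feeds random token ids (no tokenizer).
    cfg = ModelConfig(n_layers=2, d_model=128, n_heads=4, n_kv_heads=2,
                      max_seq_len=64)
    core = _CoreModel(cfg)
    ids = torch.randint(0, cfg.vocab_size, (1, 16))
    logits, loss = core(ids)
    print(logits.shape)  # torch.Size([1, 16, 50257]); loss is None here
    out = core.generate(ids, max_new_tokens=4)
    print(out.shape)     # torch.Size([1, 20])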