# NanoMind-MobileActions / hf_wrapper.py
"""
NanoMindForFunctionCalling β€” HuggingFace PreTrainedModel wrapper.
Usage after downloading from HF hub:
from hf_wrapper import NanoMindForFunctionCalling
model = NanoMindForFunctionCalling.from_pretrained("shawneil/NanoMind-MobileActions")
# model is ready for inference, no separate architecture file needed
"""
import torch
import torch.nn as nn
import torch.nn.functional as F
from dataclasses import dataclass
try:
    from transformers import PreTrainedModel, PretrainedConfig
    from transformers.modeling_outputs import CausalLMOutput
    HF_AVAILABLE = True
except ImportError:
    HF_AVAILABLE = False
# ── Minimal standalone architecture ──────────────────────────────────
@dataclass
class ModelConfig:
    vocab_size: int = 50257
    d_model: int = 512
    n_heads: int = 8
    n_kv_heads: int = 2
    n_layers: int = 8
    max_seq_len: int = 512
    ff_mult: int = 4
    dropout: float = 0.0
    use_moe: bool = False
    num_experts: int = 4
    top_k_experts: int = 2
class RMSNorm(nn.Module):
    def __init__(self, dim, eps=1e-6):
        super().__init__()
        self.eps = eps
        self.weight = nn.Parameter(torch.ones(dim))

    def forward(self, x):
        # Normalize in fp32 for numerical stability, then cast back to the
        # input dtype before applying the learned scale.
        x32 = x.float()
        return (x32 * x32.pow(2).mean(-1, keepdim=True).add(self.eps).rsqrt()).to(x.dtype) * self.weight
def _freqs_cis(head_dim, max_len, theta=10000.0):
    # Precompute RoPE rotations as unit complex phasors, one per
    # (position, frequency) pair: shape (max_len, head_dim // 2).
    freqs = 1.0 / (theta ** (torch.arange(0, head_dim, 2).float() / head_dim))
    t = torch.arange(max_len, device=freqs.device)
    freqs = torch.outer(t, freqs)
    return torch.polar(torch.ones_like(freqs), freqs)
def _rope(xq, xk, fc):
    # Apply rotary position embeddings: view the last dim as complex pairs
    # and rotate by the precomputed phasors.
    def rot(x, f):
        xc = torch.view_as_complex(x.float().reshape(*x.shape[:-1], -1, 2))
        return torch.view_as_real(xc * f[:x.shape[1]].unsqueeze(0).unsqueeze(2)).flatten(3).to(x.dtype)
    return rot(xq, fc), rot(xk, fc)
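
# Quick shape sanity check for the RoPE helpers (illustrative sizes only,
# not tied to this checkpoint's config):
#
#   fc = _freqs_cis(head_dim=64, max_len=128)  # complex tensor, (128, 32)
#   q = torch.randn(2, 16, 8, 64)              # (B, T, n_heads, head_dim)
#   k = torch.randn(2, 16, 2, 64)              # (B, T, n_kv_heads, head_dim)
#   q2, k2 = _rope(q, k, fc)
#   assert q2.shape == q.shape and k2.shape == k.shape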
class GQAttn(nn.Module):
    """Grouped-query attention: n_heads query heads share n_kv_heads K/V heads."""
    def __init__(self, cfg):
        super().__init__()
        self.nh = cfg.n_heads
        self.nkv = cfg.n_kv_heads
        self.hd = cfg.d_model // cfg.n_heads
        self.q = nn.Linear(cfg.d_model, cfg.n_heads * self.hd, bias=False)
        self.k = nn.Linear(cfg.d_model, cfg.n_kv_heads * self.hd, bias=False)
        self.v = nn.Linear(cfg.d_model, cfg.n_kv_heads * self.hd, bias=False)
        self.o = nn.Linear(cfg.n_heads * self.hd, cfg.d_model, bias=False)
        self.drop = cfg.dropout

    def forward(self, x, fc):
        B, T, _ = x.shape
        q = self.q(x).view(B, T, self.nh, self.hd)
        k = self.k(x).view(B, T, self.nkv, self.hd)
        v = self.v(x).view(B, T, self.nkv, self.hd)
        q, k = _rope(q, k, fc)
        # Expand each KV head to serve n_heads // n_kv_heads query heads.
        r = self.nh // self.nkv
        k = k.repeat_interleave(r, 2)
        v = v.repeat_interleave(r, 2)
        q, k, v = q.transpose(1, 2), k.transpose(1, 2), v.transpose(1, 2)
        out = F.scaled_dot_product_attention(
            q, k, v, attn_mask=None,
            dropout_p=self.drop if self.training else 0.0,
            is_causal=True,
        )
        return self.o(out.transpose(1, 2).contiguous().view(B, T, -1))
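
# With the defaults (n_heads=8, n_kv_heads=2) each K/V head is shared by
# r = 8 // 2 = 4 query heads, so the K and V projections are 4x narrower
# than in full multi-head attention; repeat_interleave expands them back
# at compute time.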
class SwiGLU(nn.Module):
    def __init__(self, cfg):
        super().__init__()
        # The 2/3 factor keeps parameter count comparable to a plain 4x MLP
        # (SwiGLU has three matrices); round the width up to a multiple of 64.
        h = (int(cfg.d_model * cfg.ff_mult * 2 / 3) + 63) // 64 * 64
        self.w1 = nn.Linear(cfg.d_model, h, bias=False)
        self.w2 = nn.Linear(h, cfg.d_model, bias=False)
        self.w3 = nn.Linear(cfg.d_model, h, bias=False)

    def forward(self, x):
        return self.w2(F.silu(self.w1(x)) * self.w3(x))
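
# Worked example with the default config (d_model=512, ff_mult=4):
#   int(512 * 4 * 2 / 3) = 1365  ->  rounded up to a multiple of 64 = 1408
# so each block's SwiGLU holds three 512x1408 / 1408x512 weight matrices.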
class Block(nn.Module):
    def __init__(self, cfg):
        super().__init__()
        self.an = RMSNorm(cfg.d_model)
        self.fn = RMSNorm(cfg.d_model)
        self.attn = GQAttn(cfg)
        self.ff = SwiGLU(cfg)
        self.drop = nn.Dropout(cfg.dropout)

    def forward(self, x, fc):
        # Pre-norm residual layout: x + Attn(norm(x)), then x + FF(norm(x)).
        x = x + self.drop(self.attn(self.an(x), fc))
        return x + self.drop(self.ff(self.fn(x)))
class _CoreModel(nn.Module):
    def __init__(self, cfg):
        super().__init__()
        self.cfg = cfg
        self.embed = nn.Embedding(cfg.vocab_size, cfg.d_model)
        self.drop = nn.Dropout(cfg.dropout)
        self.blocks = nn.ModuleList([Block(cfg) for _ in range(cfg.n_layers)])
        self.norm = RMSNorm(cfg.d_model)
        self.lm_head = nn.Linear(cfg.d_model, cfg.vocab_size, bias=False)
        self.embed.weight = self.lm_head.weight  # tie input/output embeddings
        # RoPE table precomputed with 2x headroom over max_seq_len.
        self.register_buffer(
            "freqs_cis",
            _freqs_cis(cfg.d_model // cfg.n_heads, cfg.max_seq_len * 2),
        )
    def forward(self, idx, targets=None, loss_mask=None):
        B, T = idx.shape
        x = self.drop(self.embed(idx))
        fc = self.freqs_cis[:T]
        for blk in self.blocks:
            x = blk(x, fc)
        logits = self.lm_head(self.norm(x))
        loss = None
        if targets is not None:
            fl = logits.view(-1, logits.size(-1))
            ft = targets.view(-1)
            if loss_mask is not None:
                # Keep only masked-in positions (e.g. function-call tokens)
                # when computing the loss.
                m = loss_mask.view(-1).bool()
                fl = fl[m]
                ft = ft[m]
            loss = F.cross_entropy(fl, ft, ignore_index=-1)
        return logits, loss
    @torch.no_grad()
    def generate(self, idx, max_new_tokens=200, temperature=0.8, top_k=50):
        for _ in range(max_new_tokens):
            ic = idx[:, -self.cfg.max_seq_len:]  # crop to the context window
            logits, _ = self(ic)
            logits = logits[:, -1, :] / temperature
            # Top-k filter: mask everything below the k-th largest logit.
            v, _ = torch.topk(logits, min(top_k, logits.size(-1)))
            logits[logits < v[:, [-1]]] = float("-inf")
            idx = torch.cat([idx, torch.multinomial(F.softmax(logits, -1), 1)], 1)
        return idx
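
# Note: this sampler keeps no KV cache; every step re-encodes the full
# (cropped) prefix, so each new token costs a forward pass over up to
# max_seq_len positions. That is acceptable for short function-call outputs.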
# ── HF-compatible wrapper ─────────────────────────────────────────────
if HF_AVAILABLE:
    class NanoMindConfig(PretrainedConfig):
        model_type = "nanomind"

        def __init__(self, vocab_size=50257, d_model=512, n_heads=8,
                     n_kv_heads=2, n_layers=8, max_seq_len=512,
                     ff_mult=4, dropout=0.0, **kwargs):
            super().__init__(**kwargs)
            self.vocab_size = vocab_size
            self.d_model = d_model
            self.n_heads = n_heads
            self.n_kv_heads = n_kv_heads
            self.n_layers = n_layers
            self.max_seq_len = max_seq_len
            self.ff_mult = ff_mult
            self.dropout = dropout
    class NanoMindForFunctionCalling(PreTrainedModel):
        config_class = NanoMindConfig

        def __init__(self, config: NanoMindConfig):
            super().__init__(config)
            cfg = ModelConfig(
                vocab_size=config.vocab_size, d_model=config.d_model,
                n_heads=config.n_heads, n_kv_heads=config.n_kv_heads,
                n_layers=config.n_layers, max_seq_len=config.max_seq_len,
                ff_mult=config.ff_mult, dropout=config.dropout,
            )
            self.model = _CoreModel(cfg)
            self.post_init()

        def forward(self, input_ids, labels=None, loss_mask=None, **kwargs):
            logits, loss = self.model(input_ids, labels, loss_mask)
            return CausalLMOutput(loss=loss, logits=logits)

        @torch.no_grad()
        def generate_text(self, idx, max_new_tokens=200, temperature=0.8, top_k=50):
            return self.model.generate(idx, max_new_tokens, temperature, top_k)
        @classmethod
        def from_nanomind_checkpoint(cls, ckpt_path: str):
            """Load from a raw NanoMind .pt checkpoint (no HF config needed)."""
            ckpt = torch.load(ckpt_path, map_location="cpu", weights_only=False)
            raw_cfg = ckpt.get("config", {})
            hf_cfg = NanoMindConfig(
                vocab_size=raw_cfg.get("vocab_size", 50257),
                d_model=raw_cfg.get("d_model", 512),
                n_heads=raw_cfg.get("n_heads", 8),
                n_kv_heads=raw_cfg.get("n_kv_heads", 2),
                n_layers=raw_cfg.get("n_layers", 8),
                max_seq_len=raw_cfg.get("max_seq_len", 512),
                ff_mult=raw_cfg.get("ff_mult", 4),
                dropout=raw_cfg.get("dropout", 0.0),
            )
            wrapper = cls(hf_cfg)
            state = ckpt["model"]
            # If the checkpoint was saved without the wrapper's "model." prefix,
            # add it so the keys match this module hierarchy.
            if not any(k.startswith("model.") for k in state):
                state = {"model." + k: v for k, v in state.items()}
            wrapper.load_state_dict(state, strict=True)
            return wrapper
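
    # A minimal conversion sketch; "ckpt.pt" and "nanomind-hf" are placeholder
    # paths, not files shipped with this repo:
    #
    #   wrapper = NanoMindForFunctionCalling.from_nanomind_checkpoint("ckpt.pt")
    #   wrapper.save_pretrained("nanomind-hf")  # writes config.json + weights
    #   model = NanoMindForFunctionCalling.from_pretrained("nanomind-hf")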
else:
    # Fallback when transformers is not installed: same weights, plain nn.Module.
    class NanoMindForFunctionCalling(nn.Module):
        def __init__(self, cfg: ModelConfig):
            super().__init__()
            self.model = _CoreModel(cfg)

        @classmethod
        def from_checkpoint(cls, ckpt_path):
            ckpt = torch.load(ckpt_path, map_location="cpu", weights_only=False)
            raw_cfg = ckpt.get("config", {})
            cfg = ModelConfig(**{k: v for k, v in raw_cfg.items()
                                 if k in ModelConfig.__dataclass_fields__})
            obj = cls(cfg)
            state = ckpt["model"]
            if not any(k.startswith("model.") for k in state):
                state = {"model." + k: v for k, v in state.items()}
            obj.load_state_dict(state, strict=True)
            return obj

        def forward(self, idx, targets=None, loss_mask=None):
            return self.model(idx, targets, loss_mask)

        def generate_text(self, idx, **kw):
            return self.model.generate(idx, **kw)
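
# Fallback usage sketch (torch-only; the checkpoint path is a placeholder):
#
#   model = NanoMindForFunctionCalling.from_checkpoint("nanomind.pt")
#   out = model.generate_text(torch.tensor([[0]]), max_new_tokens=32)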