# README / modeling_flash.py
# NodeNester's picture
# Super-squash branch 'main' using huggingface_hub
# 78c28fc
# Raw
# History Blame Contribute Delete
# 13.2 kB
"""
Standalone inference for Nordic Flash 430M (Bifrost family) — pure PyTorch,
depends only on `torch`. Independent re-implementation of the forward pass;
nothing imported from the training stack.
Hybrid architecture (~430M):
* 18 layers in a [dynamic_conv, dynamic_conv, gqa] x6 pattern.
* DynaConv layers: per-token data-dependent causal depthwise conv —
kernel_proj(x) -> (head_dim 80, K 14) softmax-over-taps weights, each kernel
shared by 16 channels; out = out_proj(causal_conv(in_proj(x)) * silu(gate_proj(x))).
* GQA layers: grouped-query attention (16/4 heads, head_dim 80, fused QKV),
partial rotary (first 25% of head dim).
* SwiGLU FFN, RMSNorm, parallel residual (one norm), tied embeddings.
"""
from __future__ import annotations
from dataclasses import dataclass, field
import torch
import torch.nn as nn
import torch.nn.functional as F
def _default_layer_types():
    """Return a fresh list with the default layer schedule.

    The schedule is the triple (dynamic_conv, dynamic_conv, gqa) repeated six
    times, i.e. 18 entries. A new list is built on every call so callers may
    mutate the result freely.
    """
    pattern = ("dynamic_conv", "dynamic_conv", "gqa")
    return [kind for _ in range(6) for kind in pattern]
@dataclass
class FlashConfig:
    """Hyper-parameters consumed by the standalone Nordic Flash modules.

    The defaults correspond to the architecture summarized in the module
    docstring. The model classes in this file build their layer stack from
    ``layer_types``; ``num_layers`` is carried along as plain metadata.
    """

    # Embedding / trunk sizes.
    vocab_size: int = 65008
    hidden_dim: int = 1280
    num_layers: int = 18
    # Grouped-query attention geometry.
    num_attention_heads: int = 16
    num_kv_heads: int = 4
    head_dim: int = 80
    # Feed-forward inner width.
    ffn_intermediate: int = 3584
    # Rotary embedding: base and the fraction of each head that is rotated.
    rope_theta: float = 500000.0
    rope_partial_factor: float = 0.25
    rms_norm_eps: float = 1e-6
    max_position: int = 4096
    dynaconv_kernel: int = 14  # K taps
    dynaconv_head_dim: int = 80  # number of distinct per-token kernels
    dynaconv_num_heads: int = 16  # channels sharing each kernel (head_dim*num_heads = hidden)
    # One entry per decoder block; "gqa" selects attention, anything else the conv mixer.
    layer_types: list = field(default_factory=_default_layer_types)

    @property
    def rotary_dim(self) -> int:
        """Number of leading head channels that get rotary embedding.

        Computed as ``int(head_dim * rope_partial_factor)``, clamped to at
        least 2 and then rounded down to an even number.
        """
        scaled = int(self.head_dim * self.rope_partial_factor)
        clamped = scaled if scaled > 2 else 2
        return clamped // 2 * 2
class RMSNorm(nn.Module):
    """RMS normalization over the last dimension with a learned per-channel scale.

    The arithmetic is carried out in float32 and the result is cast back to
    the dtype of the input.
    """

    def __init__(self, dim: int, eps: float = 1e-6) -> None:
        super().__init__()
        self.weight = nn.Parameter(torch.ones(dim))
        self.eps = eps

    def forward(self, x):
        """Return ``x / sqrt(mean(x**2) + eps) * weight`` in ``x``'s dtype."""
        in_dtype = x.dtype
        hidden = x.float()
        mean_square = hidden.pow(2).mean(-1, keepdim=True)
        inv_rms = torch.rsqrt(mean_square + self.eps)
        scaled = hidden * inv_rms * self.weight.float()
        return scaled.to(in_dtype)
def _rotate_half(x):
    """Split the last dim into halves ``[a, b]`` and return ``[-b, a]``.

    For an odd-sized last dim the first half holds ``size // 2`` elements and
    the second half the rest, exactly as the slicing below dictates.
    """
    half = x.shape[-1] // 2
    front, back = x[..., :half], x[..., half:]
    return torch.cat((-back, front), dim=-1)
class RoPE(nn.Module):
    """Partial rotary position embedding applied to queries and keys.

    Only the first ``rotary_dim`` channels of each head are rotated; the
    remaining channels pass through untouched. Positions always start at 0
    and run over the second-to-last dimension of ``q``.
    """

    def __init__(self, head_dim, rotary_dim, theta):
        # ``head_dim`` is accepted for signature compatibility; it is not used here.
        super().__init__()
        self.rotary_dim = rotary_dim
        exponents = torch.arange(0, rotary_dim, 2, dtype=torch.float32) / rotary_dim
        inv_freq = 1.0 / (theta ** exponents)
        # Non-persistent: the buffer is rebuilt from the config, never loaded from a checkpoint.
        # NOTE(review): being a buffer, it is still cast by ``module.to(dtype=...)``;
        # confirm reduced-precision frequencies are acceptable for bf16/fp16 inference.
        self.register_buffer("inv_freq", inv_freq, persistent=False)

    def forward(self, q, k):
        """Rotate the leading ``rotary_dim`` channels of ``q`` and ``k``.

        ``k`` is sliced and broadcast against the same cos/sin tables as ``q``,
        which are sized from ``q.shape[-2]``.
        """
        n_rot = self.rotary_dim
        seq_len = q.shape[-2]
        positions = torch.arange(seq_len, device=q.device, dtype=torch.float32)
        angles = torch.outer(positions, self.inv_freq.to(q.device))
        # Duplicate so the table lines up with the two halves used by _rotate_half.
        angles = torch.cat([angles, angles], dim=-1)
        cos = angles.cos().to(q.dtype)[None, None]
        sin = angles.sin().to(q.dtype)[None, None]

        def rotate(t):
            head, tail = t[..., :n_rot], t[..., n_rot:]
            head = head * cos + _rotate_half(head) * sin
            return torch.cat([head, tail], -1)

        return rotate(q), rotate(k)
class Attention(nn.Module):
    """Causal grouped-query self-attention with a fused QKV projection.

    Queries use ``num_attention_heads`` heads; keys and values use
    ``num_kv_heads`` heads that are repeated to match the query heads before
    the scaled-dot-product call.
    """

    def __init__(self, cfg: FlashConfig):
        super().__init__()
        self.nq = cfg.num_attention_heads
        self.nkv = cfg.num_kv_heads
        self.hd = cfg.head_dim
        self.groups = self.nq // self.nkv  # query heads served by each kv head
        q_width = self.nq * self.hd
        kv_width = self.nkv * self.hd
        self.qkv_proj = nn.Linear(cfg.hidden_dim, q_width + 2 * kv_width, bias=False)
        self.o_proj = nn.Linear(q_width, cfg.hidden_dim, bias=False)
        # Column widths used to cut the fused projection back into q / k / v.
        self._split = (q_width, kv_width, kv_width)
        self.rope = RoPE(self.hd, cfg.rotary_dim, cfg.rope_theta)

    def forward(self, x):
        """Apply causal attention to ``x`` (unpacked as batch, seq, features)."""
        bsz, seq, _ = x.shape

        def to_heads(flat, n_heads):
            # (batch, seq, n_heads * hd) -> (batch, n_heads, seq, hd)
            return flat.view(bsz, seq, n_heads, self.hd).transpose(1, 2)

        q_flat, k_flat, v_flat = self.qkv_proj(x).split(self._split, dim=-1)
        q = to_heads(q_flat, self.nq)
        k = to_heads(k_flat, self.nkv)
        v = to_heads(v_flat, self.nkv)
        q, k = self.rope(q, k)
        if self.groups > 1:
            # Expand kv heads so every query head has a matching key/value head.
            k, v = (t.repeat_interleave(self.groups, dim=1) for t in (k, v))
        ctx = F.scaled_dot_product_attention(q, k, v, is_causal=True)
        merged = ctx.transpose(1, 2).reshape(bsz, seq, self.nq * self.hd)
        return self.o_proj(merged)
class DynaConv(nn.Module):
    """Causal depthwise convolution whose taps are predicted per token, with a SiLU gate.

    For every position, ``kernel_proj`` emits ``dynaconv_head_dim`` kernels of
    ``K`` taps each; the taps are softmax-normalized and every kernel is
    shared by ``dynaconv_num_heads`` consecutive channels. The kernels are
    applied to a left-padded window of ``in_proj(x)``, the result is multiplied
    by ``silu(gate_proj(x))`` and passed through ``out_proj``.
    """

    def __init__(self, cfg: FlashConfig):
        super().__init__()
        width = cfg.hidden_dim
        self.K = cfg.dynaconv_kernel
        self.hdc = cfg.dynaconv_head_dim  # distinct kernels
        self.nh = cfg.dynaconv_num_heads  # channels per kernel
        assert self.hdc * self.nh == width
        self.in_proj = nn.Linear(width, width, bias=False)
        self.gate_proj = nn.Linear(width, width, bias=False)
        self.out_proj = nn.Linear(width, width, bias=False)
        self.kernel_proj = nn.Linear(width, self.hdc * self.K, bias=False)

    def forward(self, x):
        """Mix ``x`` (unpacked as batch, seq, hidden) along the sequence axis."""
        batch, seq, _ = x.shape
        taps = self.K
        # Per-token kernels: softmax over the tap axis, computed in float32.
        logits = self.kernel_proj(x).view(batch, seq, self.hdc, taps)
        kernels = F.softmax(logits.float(), dim=-1).to(x.dtype)
        # Channel c uses kernel c // nh.
        per_channel = kernels.repeat_interleave(self.nh, dim=2)
        values = self.in_proj(x)
        gate = F.silu(self.gate_proj(x))
        # Left-pad the sequence by K-1 so each window ends at the current token,
        # then unfold into (batch, seq, hidden, K) windows; index K-1 is "now".
        windows = F.pad(values, (0, 0, taps - 1, 0)).unfold(1, taps, 1)
        mixed = (windows * per_channel).sum(-1)
        return self.out_proj(mixed * gate)
class SwiGLU(nn.Module):
    """Bias-free gated feed-forward block: ``down(silu(gate(x)) * up(x))``."""

    def __init__(self, cfg: FlashConfig):
        super().__init__()
        hidden, inner = cfg.hidden_dim, cfg.ffn_intermediate
        self.gate_proj = nn.Linear(hidden, inner, bias=False)
        self.up_proj = nn.Linear(hidden, inner, bias=False)
        self.down_proj = nn.Linear(inner, hidden, bias=False)

    def forward(self, x):
        """Project up to the inner width, gate with SiLU, project back down."""
        gated = F.silu(self.gate_proj(x))
        lifted = self.up_proj(x)
        return self.down_proj(gated * lifted)
class DecoderBlock(nn.Module):
    """One decoder layer with a parallel residual.

    A single RMSNorm feeds both the token mixer and the feed-forward block;
    their outputs are added to the un-normalized input. The mixer is stored
    under the attribute name ``attention`` regardless of its type.
    """

    def __init__(self, cfg: FlashConfig, layer_type: str):
        super().__init__()
        self.input_norm = RMSNorm(cfg.hidden_dim, cfg.rms_norm_eps)
        # "gqa" selects attention; every other value selects the dynamic conv.
        mixer_cls = Attention if layer_type == "gqa" else DynaConv
        self.attention = mixer_cls(cfg)
        self.ffn = SwiGLU(cfg)

    def forward(self, x):
        """Return ``x + mixer(norm(x)) + ffn(norm(x))``."""
        normed = self.input_norm(x)
        mixed = self.attention(normed)
        return x + mixed + self.ffn(normed)
class NordicFlash(nn.Module):
    """Standalone decoder-only model: embeddings, decoder blocks, final norm, LM head.

    The layer stack is built from ``cfg.layer_types`` (one ``DecoderBlock`` per
    entry). No key/value or convolution state is cached; every forward pass
    processes the full sequence.
    """

    def __init__(self, cfg: "FlashConfig"):
        super().__init__()
        self.cfg = cfg
        self.embed_tokens = nn.Embedding(cfg.vocab_size, cfg.hidden_dim)
        self.layers = nn.ModuleList(DecoderBlock(cfg, t) for t in cfg.layer_types)
        self.final_norm = RMSNorm(cfg.hidden_dim, cfg.rms_norm_eps)
        self.lm_head = nn.Linear(cfg.hidden_dim, cfg.vocab_size, bias=False)

    def forward(self, input_ids):
        """Return next-token logits for every position of ``input_ids``."""
        x = self.embed_tokens(input_ids)
        for layer in self.layers:
            x = layer(x)
        return self.lm_head(self.final_norm(x))

    @torch.no_grad()
    def translate(self, src_ids, tgt_lang_id, *, bos=1, eos=2, eos_src=65007,
                  max_new=256, vocab_text_limit=65000):
        """Greedy-decode a continuation of ``[bos, tgt_lang_id, *src_ids, eos_src]``.

        Args:
            src_ids: iterable of source token ids.
            tgt_lang_id: token id placed right after ``bos`` to select the target language.
            bos, eos, eos_src: special token ids; decoding stops when ``eos`` is produced.
            max_new: upper bound on the number of generated tokens.
            vocab_text_limit: generated ids at or above this value are dropped
                from the returned list (they are still fed back to the model).

        Returns:
            List of generated token ids below ``vocab_text_limit``, without ``eos``.
        """
        dev = self.embed_tokens.weight.device
        ids = [bos, tgt_lang_id] + list(src_ids) + [eos_src]
        out = []
        for _ in range(max_new):
            # Full recompute each step (local conv + attn); there is no cache.
            logits = self(torch.tensor([ids], device=dev))
            nxt = int(logits[0, -1].argmax())
            if nxt == eos:
                break
            out.append(nxt)
            ids.append(nxt)
        return [t for t in out if t < vocab_text_limit]

    @staticmethod
    def _strip_compile_prefix(key):
        """Remove ``torch.compile`` wrapper segments from a state-dict key.

        Handles both a compiled submodule (``layers.0._orig_mod.x``) and a
        compiled top-level model, whose keys start with ``_orig_mod.``.
        """
        prefix = "_orig_mod."
        if key.startswith(prefix):
            key = key[len(prefix):]
        return key.replace("._orig_mod", "")

    @classmethod
    def from_checkpoint(cls, path, device="cuda", dtype=torch.bfloat16, cfg=None):
        """Build a model, load weights from ``path`` and return it in eval mode.

        Args:
            path: ``str`` or ``os.PathLike``. Files ending in ``.safetensors``
                are read with ``safetensors``; anything else goes through
                ``torch.load``.
            device, dtype: passed to ``Module.to`` after loading.
            cfg: optional ``FlashConfig``; a default one is created when omitted.

        Raises:
            RuntimeError: if the checkpoint has missing or unexpected keys
                (``inv_freq`` entries are ignored on both sides).
        """
        import os

        path = os.fspath(path)  # accept pathlib.Path as well as str
        cfg = cfg if cfg is not None else FlashConfig()
        model = cls(cfg)
        if path.endswith(".safetensors"):
            from safetensors.torch import load_file
            sd = load_file(path)
        else:
            # SECURITY: weights_only=False unpickles arbitrary Python objects.
            # Only load checkpoints from a trusted source; switch to
            # weights_only=True if the checkpoint holds nothing but tensors.
            sd = torch.load(path, map_location="cpu", weights_only=False)
            # Training checkpoints may nest the weights under a "model" entry.
            sd = sd.get("model", sd)
        sd = {cls._strip_compile_prefix(k): v for k, v in sd.items()}
        if "lm_head.weight" not in sd and "embed_tokens.weight" in sd:
            sd["lm_head.weight"] = sd["embed_tokens.weight"]  # tied
        missing, unexpected = model.load_state_dict(sd, strict=False)
        missing = [m for m in missing if "inv_freq" not in m]
        unexpected = [u for u in unexpected if "inv_freq" not in u]
        if missing or unexpected:
            raise RuntimeError(f"state_dict mismatch:\n missing={missing}\n unexpected={unexpected}")
        return model.to(device=device, dtype=dtype).eval()
# --------------------------------------------------------------------------- #
# Optional HuggingFace wrapper — AutoModelForCausalLM.from_pretrained + #
# model.generate() (trust_remote_code=True). Cache-free recompute for #
# version-robustness; NordicFlash.translate() is the fast standalone path. #
# --------------------------------------------------------------------------- #
try:
    from transformers import PretrainedConfig, PreTrainedModel
    from transformers.modeling_outputs import CausalLMOutputWithPast

    try:
        from transformers import GenerationMixin
    except ImportError:  # fall back to the legacy module path used by older releases
        from transformers.generation_utils import GenerationMixin

    class NordicFlashConfig(PretrainedConfig):
        """Hugging Face config mirroring ``FlashConfig``.

        Besides the native field names it also sets the conventional HF
        aliases (``hidden_size``, ``num_hidden_layers``, ``num_key_value_heads``,
        ``max_position_embeddings``). ``tie_word_embeddings`` defaults to True
        and ``use_cache`` to False unless the caller overrides them.
        """

        model_type = "nordic_flash"

        def __init__(self, vocab_size=65008, hidden_dim=1280, num_layers=18,
                     num_attention_heads=16, num_kv_heads=4, head_dim=80,
                     ffn_intermediate=3584, rope_theta=500000.0,
                     rope_partial_factor=0.25, rms_norm_eps=1e-6, max_position=4096,
                     dynaconv_kernel=14, dynaconv_head_dim=80, dynaconv_num_heads=16,
                     layer_types=None, **kwargs):
            self.vocab_size = vocab_size
            self.hidden_dim = hidden_dim
            self.num_layers = num_layers
            self.num_attention_heads = num_attention_heads
            self.num_kv_heads = num_kv_heads
            self.head_dim = head_dim
            self.ffn_intermediate = ffn_intermediate
            self.rope_theta = rope_theta
            self.rope_partial_factor = rope_partial_factor
            self.rms_norm_eps = rms_norm_eps
            self.max_position = max_position
            self.dynaconv_kernel = dynaconv_kernel
            self.dynaconv_head_dim = dynaconv_head_dim
            self.dynaconv_num_heads = dynaconv_num_heads
            # A missing or empty schedule falls back to the default pattern.
            self.layer_types = layer_types or _default_layer_types()
            # Aliases under the attribute names HF utilities look for.
            self.num_hidden_layers = num_layers
            self.hidden_size = hidden_dim
            self.num_key_value_heads = num_kv_heads
            self.max_position_embeddings = max_position
            kwargs.setdefault("tie_word_embeddings", True)
            kwargs.setdefault("use_cache", False)
            super().__init__(**kwargs)

        def to_flash(self):
            """Return the equivalent ``FlashConfig`` (with its own copy of ``layer_types``)."""
            return FlashConfig(
                vocab_size=self.vocab_size, hidden_dim=self.hidden_dim,
                num_layers=self.num_layers, num_attention_heads=self.num_attention_heads,
                num_kv_heads=self.num_kv_heads, head_dim=self.head_dim,
                ffn_intermediate=self.ffn_intermediate, rope_theta=self.rope_theta,
                rope_partial_factor=self.rope_partial_factor, rms_norm_eps=self.rms_norm_eps,
                max_position=self.max_position, dynaconv_kernel=self.dynaconv_kernel,
                dynaconv_head_dim=self.dynaconv_head_dim, dynaconv_num_heads=self.dynaconv_num_heads,
                layer_types=list(self.layer_types))

    # GenerationMixin is listed explicitly: recent transformers releases no
    # longer provide generate() through PreTrainedModel alone.
    class NordicFlashForCausalLM(PreTrainedModel, GenerationMixin):
        """Causal-LM wrapper exposing the same module layout as ``NordicFlash``.

        Generation runs without any cache: each step re-evaluates the whole
        sequence. ``attention_mask`` and ``past_key_values`` are accepted for
        API compatibility and ignored.
        """

        config_class = NordicFlashConfig
        _supports_cache_class = False
        _no_split_modules = ["DecoderBlock"]

        def __init__(self, config: "NordicFlashConfig"):
            super().__init__(config)
            cfg = config.to_flash()
            self.embed_tokens = nn.Embedding(cfg.vocab_size, cfg.hidden_dim)
            self.layers = nn.ModuleList(DecoderBlock(cfg, t) for t in cfg.layer_types)
            self.final_norm = RMSNorm(cfg.hidden_dim, cfg.rms_norm_eps)
            self.lm_head = nn.Linear(cfg.hidden_dim, cfg.vocab_size, bias=False)
            self.post_init()

        def get_input_embeddings(self):
            return self.embed_tokens

        def set_input_embeddings(self, v):
            self.embed_tokens = v

        def get_output_embeddings(self):
            return self.lm_head

        def set_output_embeddings(self, new_embeddings):
            self.lm_head = new_embeddings

        def forward(self, input_ids=None, attention_mask=None, past_key_values=None,
                    use_cache=None, labels=None, return_dict=True, **kwargs):
            """Run the full stack and return logits (plus a loss when ``labels`` is given).

            The loss is the usual next-token cross-entropy: logits at position
            ``t`` are scored against ``labels`` at ``t + 1``, with ``-100``
            entries ignored. A ``CausalLMOutputWithPast`` is always returned;
            ``return_dict`` is accepted but not consulted.
            """
            x = self.embed_tokens(input_ids)
            for layer in self.layers:
                x = layer(x)
            logits = self.lm_head(self.final_norm(x))
            loss = None
            if labels is not None:
                shift_logits = logits[:, :-1, :].float()
                shift_labels = labels[:, 1:].to(shift_logits.device)
                loss = F.cross_entropy(
                    shift_logits.reshape(-1, shift_logits.size(-1)),
                    shift_labels.reshape(-1),
                    ignore_index=-100,
                )
            return CausalLMOutputWithPast(loss=loss, logits=logits)

        def prepare_inputs_for_generation(self, input_ids, **kwargs):
            """Feed the whole running sequence every step; no cache is used."""
            return {"input_ids": input_ids, "use_cache": False}
except ImportError:
    # transformers is optional; without it only the standalone classes are available.
    pass