"""
Standalone inference for Nordic Flash 430M (Bifrost family) — pure PyTorch,
depends only on `torch`. Independent re-implementation of the forward pass;
nothing imported from the training stack.

Hybrid architecture (~430M):
  * 18 layers in a [dynamic_conv, dynamic_conv, gqa] x6 pattern.
  * DynaConv layers: per-token data-dependent causal depthwise conv —
    kernel_proj(x) -> (head_dim 80, K 14) softmax-over-taps weights, each
    kernel shared by 16 channels;
    out = out_proj(causal_conv(in_proj(x)) * silu(gate_proj(x))).
  * GQA layers: grouped-query attention (16/4 heads, head_dim 80, fused QKV),
    partial rotary (first 25% of head dim).
  * SwiGLU FFN, RMSNorm, parallel residual (one norm), tied embeddings.
"""

from __future__ import annotations

from dataclasses import dataclass, field

import torch
import torch.nn as nn
import torch.nn.functional as F


def _default_layer_types():
    """Return the default layer schedule: [dynamic_conv, dynamic_conv, gqa] x6."""
    return ["dynamic_conv", "dynamic_conv", "gqa"] * 6


@dataclass
class FlashConfig:
    """Architecture hyper-parameters for :class:`NordicFlash`.

    Depth is determined by ``layer_types`` (one decoder block per entry);
    ``num_layers`` and ``max_position`` are not read by the model classes in
    this module.
    """

    vocab_size: int = 65008
    hidden_dim: int = 1280
    num_layers: int = 18
    num_attention_heads: int = 16
    num_kv_heads: int = 4
    head_dim: int = 80
    ffn_intermediate: int = 3584
    rope_theta: float = 500000.0
    rope_partial_factor: float = 0.25
    rms_norm_eps: float = 1e-6
    max_position: int = 4096
    dynaconv_kernel: int = 14  # K taps
    dynaconv_head_dim: int = 80  # number of distinct per-token kernels
    dynaconv_num_heads: int = 16  # channels sharing each kernel (head_dim*num_heads = hidden)
    layer_types: list = field(default_factory=_default_layer_types)

    @property
    def rotary_dim(self) -> int:
        """Leading head dims that get RoPE: head_dim * factor, at least 2, rounded down to even."""
        rd = max(2, int(self.head_dim * self.rope_partial_factor))
        return rd - (rd % 2)


class RMSNorm(nn.Module):
    """Root-mean-square norm computed in fp32, returned in the input dtype."""

    def __init__(self, dim: int, eps: float = 1e-6) -> None:
        super().__init__()
        self.eps = eps
        self.weight = nn.Parameter(torch.ones(dim))

    def forward(self, x):
        dtype = x.dtype
        x32 = x.float()
        x32 = x32 * torch.rsqrt(x32.pow(2).mean(-1, keepdim=True) + self.eps)
        return (x32 * self.weight.float()).to(dtype)


def _rotate_half(x):
    """Map last-dim halves (a, b) -> (-b, a) (non-interleaved RoPE layout)."""
    d = x.shape[-1]
    return torch.cat([-x[..., d // 2:], x[..., : d // 2]], dim=-1)


class RoPE(nn.Module):
    """Partial rotary position embedding (rotate-half layout).

    Only the first ``rotary_dim`` features of each head are rotated; the rest
    pass through unchanged. Positions are always ``0 .. seq_len - 1`` (no
    offset / cache support).
    """

    def __init__(self, head_dim, rotary_dim, theta):
        super().__init__()
        self.rotary_dim = rotary_dim
        self.theta = theta  # kept so exact fp32 frequencies can be rebuilt after a dtype cast
        idx = torch.arange(0, rotary_dim, 2, dtype=torch.float32)
        self.register_buffer("inv_freq", 1.0 / (theta ** (idx / rotary_dim)), persistent=False)

    def forward(self, q, k):
        """Rotate ``q``/``k`` as laid out by :class:`Attention`: (batch, heads, seq, head_dim)."""
        rot = self.rotary_dim
        inv_freq = self.inv_freq
        if inv_freq.dtype in (torch.float16, torch.bfloat16):
            # Module.to(dtype=...) also casts floating buffers. Half-precision
            # frequencies make the rotation angles drift badly at long
            # positions, so rebuild them exactly as __init__ did, in fp32.
            idx = torch.arange(0, rot, 2, dtype=torch.float32)
            inv_freq = 1.0 / (self.theta ** (idx / rot))
        t = torch.arange(q.shape[-2], device=q.device, dtype=torch.float32)
        freqs = torch.outer(t, inv_freq.to(q.device))
        emb = torch.cat([freqs, freqs], dim=-1)
        cos = emb.cos().to(q.dtype)[None, None]
        sin = emb.sin().to(q.dtype)[None, None]
        qr, qp = q[..., :rot], q[..., rot:]
        kr, kp = k[..., :rot], k[..., rot:]
        qr = qr * cos + _rotate_half(qr) * sin
        kr = kr * cos + _rotate_half(kr) * sin
        return torch.cat([qr, qp], -1), torch.cat([kr, kp], -1)


class Attention(nn.Module):
    """Causal grouped-query attention with a fused QKV projection and partial RoPE."""

    def __init__(self, cfg: FlashConfig):
        super().__init__()
        self.nq, self.nkv, self.hd = cfg.num_attention_heads, cfg.num_kv_heads, cfg.head_dim
        if self.nq % self.nkv:
            raise ValueError(
                f"num_attention_heads ({self.nq}) must be divisible by num_kv_heads ({self.nkv})"
            )
        self.groups = self.nq // self.nkv
        q, kv = self.nq * self.hd, self.nkv * self.hd
        self.qkv_proj = nn.Linear(cfg.hidden_dim, q + 2 * kv, bias=False)
        self.o_proj = nn.Linear(q, cfg.hidden_dim, bias=False)
        self._split = (q, kv, kv)
        self.rope = RoPE(self.hd, cfg.rotary_dim, cfg.rope_theta)

    def forward(self, x):
        """x: (batch, seq, hidden) -> (batch, seq, hidden), causal over seq."""
        b, s, _ = x.shape
        qf, kf, vf = self.qkv_proj(x).split(self._split, dim=-1)
        q = qf.view(b, s, self.nq, self.hd).transpose(1, 2)
        k = kf.view(b, s, self.nkv, self.hd).transpose(1, 2)
        v = vf.view(b, s, self.nkv, self.hd).transpose(1, 2)
        q, k = self.rope(q, k)
        if self.groups > 1:
            # Query head h attends with kv head h // groups.
            k = k.repeat_interleave(self.groups, dim=1)
            v = v.repeat_interleave(self.groups, dim=1)
        out = F.scaled_dot_product_attention(q, k, v, is_causal=True)
        return self.o_proj(out.transpose(1, 2).reshape(b, s, self.nq * self.hd))


class DynaConv(nn.Module):
    """Per-token data-dependent causal depthwise conv with a silu gate.

    kernel_proj(x) -> (head_dim, K) softmax-over-taps weights, each kernel
    shared by num_heads channels; causal depthwise conv of in_proj(x); gated
    by silu(gate).
    out = out_proj( causal_conv(in_proj(x), softmax_kernel) * silu(gate_proj(x)) ).
    """

    def __init__(self, cfg: FlashConfig):
        super().__init__()
        D = cfg.hidden_dim
        self.K = cfg.dynaconv_kernel
        self.hdc = cfg.dynaconv_head_dim  # distinct kernels
        self.nh = cfg.dynaconv_num_heads  # channels per kernel
        if self.hdc * self.nh != D:
            raise ValueError(
                f"dynaconv_head_dim * dynaconv_num_heads ({self.hdc} * {self.nh}) "
                f"must equal hidden_dim ({D})"
            )
        self.in_proj = nn.Linear(D, D, bias=False)
        self.gate_proj = nn.Linear(D, D, bias=False)
        self.out_proj = nn.Linear(D, D, bias=False)
        self.kernel_proj = nn.Linear(D, self.hdc * self.K, bias=False)

    def forward(self, x):
        """x: (batch, seq, hidden) -> (batch, seq, hidden); output at t sees inputs t-K+1..t."""
        B, S, D = x.shape
        K = self.K
        k = self.kernel_proj(x).view(B, S, self.hdc, K)  # (B,S,head_dim,K)
        k = F.softmax(k.float(), dim=-1).to(x.dtype)  # softmax over taps
        w = k.repeat_interleave(self.nh, dim=2)  # (B,S,D,K); channel c -> kernel c//nh
        bx = self.in_proj(x)
        g = self.gate_proj(x)
        xp = F.pad(bx, (0, 0, K - 1, 0))  # left-pad seq by K-1
        xu = xp.unfold(1, K, 1)  # (B,S,D,K); xu[...,K-1]=bx[t] (current)
        y = (xu * w).sum(-1) * F.silu(g)  # (B,S,D)
        return self.out_proj(y)


class SwiGLU(nn.Module):
    """Gated feed-forward: down_proj(silu(gate_proj(x)) * up_proj(x))."""

    def __init__(self, cfg: FlashConfig):
        super().__init__()
        self.gate_proj = nn.Linear(cfg.hidden_dim, cfg.ffn_intermediate, bias=False)
        self.up_proj = nn.Linear(cfg.hidden_dim, cfg.ffn_intermediate, bias=False)
        self.down_proj = nn.Linear(cfg.ffn_intermediate, cfg.hidden_dim, bias=False)

    def forward(self, x):
        return self.down_proj(F.silu(self.gate_proj(x)) * self.up_proj(x))


class DecoderBlock(nn.Module):
    """Parallel-residual block: x + mixer(norm(x)) + ffn(norm(x)) with one shared norm.

    ``layer_type == "gqa"`` selects :class:`Attention`; any other value selects
    :class:`DynaConv`. The attribute is named ``attention`` in both cases.
    """

    def __init__(self, cfg: FlashConfig, layer_type: str):
        super().__init__()
        self.input_norm = RMSNorm(cfg.hidden_dim, cfg.rms_norm_eps)
        self.attention = Attention(cfg) if layer_type == "gqa" else DynaConv(cfg)
        self.ffn = SwiGLU(cfg)

    def forward(self, x):
        h = self.input_norm(x)
        return x + self.attention(h) + self.ffn(h)


class NordicFlash(nn.Module):
    """Standalone Nordic Flash language model (embeddings -> blocks -> norm -> lm_head).

    ``lm_head`` is its own ``nn.Linear``; :meth:`from_checkpoint` shares it with
    ``embed_tokens`` when the checkpoint carries no ``lm_head.weight``.
    """

    def __init__(self, cfg: FlashConfig):
        super().__init__()
        self.cfg = cfg
        self.embed_tokens = nn.Embedding(cfg.vocab_size, cfg.hidden_dim)
        self.layers = nn.ModuleList(DecoderBlock(cfg, t) for t in cfg.layer_types)
        self.final_norm = RMSNorm(cfg.hidden_dim, cfg.rms_norm_eps)
        self.lm_head = nn.Linear(cfg.hidden_dim, cfg.vocab_size, bias=False)

    def forward(self, input_ids):
        """input_ids: (batch, seq) token ids -> logits (batch, seq, vocab_size)."""
        x = self.embed_tokens(input_ids)
        for layer in self.layers:
            x = layer(x)
        return self.lm_head(self.final_norm(x))

    @torch.no_grad()
    def translate(self, src_ids, tgt_lang_id, *, bos=1, eos=2, eos_src=65007,
                  max_new=256, vocab_text_limit=65000):
        """Greedy-decode a translation of one source sequence.

        The prompt is ``[bos, tgt_lang_id, *src_ids, eos_src]``. Each step
        re-runs the full sequence (no cache) and appends the arg-max token,
        stopping at ``eos`` (not included) or after ``max_new`` tokens.
        Returns the generated ids, dropping any ``>= vocab_text_limit``.
        """
        dev = self.embed_tokens.weight.device
        ids = [bos, tgt_lang_id] + list(src_ids) + [eos_src]
        out = []
        for _ in range(max_new):
            logits = self(torch.tensor([ids], device=dev))  # full recompute (local conv + attn)
            nxt = int(logits[0, -1].argmax())
            if nxt == eos:
                break
            out.append(nxt)
            ids.append(nxt)
        return [t for t in out if t < vocab_text_limit]

    @classmethod
    def from_checkpoint(cls, path, device="cuda", dtype=torch.bfloat16, cfg=None):
        """Build a model from a ``.safetensors`` or ``torch.save`` checkpoint.

        Args:
            path: checkpoint location (``str`` or ``os.PathLike``).
            device: target device of the returned model.
            dtype: target dtype of the returned model.
            cfg: architecture config; defaults to ``FlashConfig()``.

        Returns:
            The loaded model, moved to ``device``/``dtype``, in eval mode.

        Raises:
            RuntimeError: if the checkpoint keys do not match the model (the
                non-persistent RoPE ``inv_freq`` buffers are ignored).
        """
        cfg = cfg or FlashConfig()
        model = cls(cfg)
        path = str(path)  # accept pathlib.Path as well as str
        if path.endswith(".safetensors"):
            from safetensors.torch import load_file
            sd = load_file(path)
        else:
            # SECURITY: weights_only=False unpickles arbitrary objects and can
            # execute code; only load checkpoints from trusted sources.
            sd = torch.load(path, map_location="cpu", weights_only=False)
        sd = sd.get("model", sd)
        # torch.compile wraps modules as `_orig_mod`; strip it both as a
        # top-level prefix ("_orig_mod.x") and when nested ("a._orig_mod.x").
        sd = {k.replace("_orig_mod.", ""): v for k, v in sd.items()}
        if "lm_head.weight" not in sd and "embed_tokens.weight" in sd:
            # Tied embeddings: share one Parameter instead of holding two copies.
            model.lm_head.weight = model.embed_tokens.weight
            sd["lm_head.weight"] = sd["embed_tokens.weight"]
        missing, unexpected = model.load_state_dict(sd, strict=False)
        missing = [m for m in missing if "inv_freq" not in m]
        unexpected = [u for u in unexpected if "inv_freq" not in u]
        if missing or unexpected:
            raise RuntimeError(f"state_dict mismatch:\n missing={missing}\n unexpected={unexpected}")
        return model.to(device=device, dtype=dtype).eval()


# --------------------------------------------------------------------------- #
# Optional HuggingFace wrapper — AutoModelForCausalLM.from_pretrained +
# model.generate() (trust_remote_code=True).
# Cache-free recompute for version-robustness; NordicFlash.translate() is the
# fast standalone path.
# --------------------------------------------------------------------------- #
try:
    from transformers import PretrainedConfig, PreTrainedModel
    from transformers.modeling_outputs import CausalLMOutputWithPast

    class NordicFlashConfig(PretrainedConfig):
        """HuggingFace config mirroring ``FlashConfig``.

        Also sets the standard HF aliases (``num_hidden_layers``,
        ``hidden_size``, ``num_key_value_heads``, ``max_position_embeddings``)
        and defaults ``tie_word_embeddings=True`` and ``use_cache=False``.
        """

        model_type = "nordic_flash"

        def __init__(self, vocab_size=65008, hidden_dim=1280, num_layers=18,
                     num_attention_heads=16, num_kv_heads=4, head_dim=80,
                     ffn_intermediate=3584, rope_theta=500000.0,
                     rope_partial_factor=0.25, rms_norm_eps=1e-6,
                     max_position=4096, dynaconv_kernel=14,
                     dynaconv_head_dim=80, dynaconv_num_heads=16,
                     layer_types=None, **kwargs):
            self.vocab_size = vocab_size
            self.hidden_dim = hidden_dim
            self.num_layers = num_layers
            self.num_attention_heads = num_attention_heads
            self.num_kv_heads = num_kv_heads
            self.head_dim = head_dim
            self.ffn_intermediate = ffn_intermediate
            self.rope_theta = rope_theta
            self.rope_partial_factor = rope_partial_factor
            self.rms_norm_eps = rms_norm_eps
            self.max_position = max_position
            self.dynaconv_kernel = dynaconv_kernel
            self.dynaconv_head_dim = dynaconv_head_dim
            self.dynaconv_num_heads = dynaconv_num_heads
            self.layer_types = layer_types or _default_layer_types()
            # Standard HF attribute names used by generic tooling.
            self.num_hidden_layers = num_layers
            self.hidden_size = hidden_dim
            self.num_key_value_heads = num_kv_heads
            self.max_position_embeddings = max_position
            kwargs.setdefault("tie_word_embeddings", True)
            kwargs.setdefault("use_cache", False)
            super().__init__(**kwargs)

        def to_flash(self):
            """Return the equivalent plain ``FlashConfig`` used by the modules."""
            return FlashConfig(
                vocab_size=self.vocab_size, hidden_dim=self.hidden_dim,
                num_layers=self.num_layers,
                num_attention_heads=self.num_attention_heads,
                num_kv_heads=self.num_kv_heads, head_dim=self.head_dim,
                ffn_intermediate=self.ffn_intermediate,
                rope_theta=self.rope_theta,
                rope_partial_factor=self.rope_partial_factor,
                rms_norm_eps=self.rms_norm_eps,
                max_position=self.max_position,
                dynaconv_kernel=self.dynaconv_kernel,
                dynaconv_head_dim=self.dynaconv_head_dim,
                dynaconv_num_heads=self.dynaconv_num_heads,
                layer_types=list(self.layer_types),
            )

    class NordicFlashForCausalLM(PreTrainedModel):
        """HF causal-LM wrapper built from the same blocks as ``NordicFlash``.

        Parameter names match ``NordicFlash`` (``embed_tokens``, ``layers``,
        ``final_norm``, ``lm_head``). There is no cache: every forward
        recomputes the whole sequence. ``attention_mask`` is accepted but
        ignored. Positions always run 0..seq-1, and padding tokens are
        processed like real ones, so padded batches (especially left-padded
        generation) are not supported.
        """

        config_class = NordicFlashConfig
        _supports_cache_class = False
        _no_split_modules = ["DecoderBlock"]

        def __init__(self, config: "NordicFlashConfig"):
            super().__init__(config)
            cfg = config.to_flash()
            self.embed_tokens = nn.Embedding(cfg.vocab_size, cfg.hidden_dim)
            self.layers = nn.ModuleList(DecoderBlock(cfg, t) for t in cfg.layer_types)
            self.final_norm = RMSNorm(cfg.hidden_dim, cfg.rms_norm_eps)
            self.lm_head = nn.Linear(cfg.hidden_dim, cfg.vocab_size, bias=False)
            self.post_init()

        def get_input_embeddings(self):
            return self.embed_tokens

        def set_input_embeddings(self, v):
            self.embed_tokens = v

        def get_output_embeddings(self):
            return self.lm_head

        def set_output_embeddings(self, new_embeddings):
            # Needed by HF utilities that swap the output head (e.g. vocab resizing).
            self.lm_head = new_embeddings

        def forward(self, input_ids=None, attention_mask=None, past_key_values=None,
                    use_cache=None, labels=None, return_dict=True, **kwargs):
            """Run the full-sequence forward pass.

            Args:
                input_ids: (batch, seq) token ids. May be omitted when
                    ``inputs_embeds`` is passed as a keyword argument.
                attention_mask, past_key_values, use_cache: accepted for HF API
                    compatibility and ignored.
                labels: optional (batch, seq) targets aligned with the inputs.
                    They are shifted internally; ``-100`` entries are ignored.
                return_dict: ``False`` returns a tuple; ``None`` defers to
                    ``config.use_return_dict``.

            Returns:
                ``CausalLMOutputWithPast(loss=..., logits=...)``, or
                ``(loss, logits)`` / ``(logits,)`` when ``return_dict`` is false.

            Raises:
                ValueError: if neither ``input_ids`` nor ``inputs_embeds`` is given.
            """
            inputs_embeds = kwargs.get("inputs_embeds")
            if inputs_embeds is None:
                if input_ids is None:
                    raise ValueError("Either input_ids or inputs_embeds must be provided")
                inputs_embeds = self.embed_tokens(input_ids)
            x = inputs_embeds
            for layer in self.layers:
                x = layer(x)
            logits = self.lm_head(self.final_norm(x))

            loss = None
            if labels is not None:
                # Causal-LM objective: logits at position t predict labels at t + 1.
                shift_logits = logits[:, :-1, :].float()
                shift_labels = labels[:, 1:].to(shift_logits.device)
                loss = F.cross_entropy(
                    shift_logits.reshape(-1, shift_logits.size(-1)),
                    shift_labels.reshape(-1),
                    ignore_index=-100,
                )

            if return_dict is None:
                return_dict = getattr(self.config, "use_return_dict", True)
            if not return_dict:
                return (loss, logits) if loss is not None else (logits,)
            return CausalLMOutputWithPast(loss=loss, logits=logits)

        def prepare_inputs_for_generation(self, input_ids, **kwargs):
            # No cache: always feed the full sequence generated so far.
            return {"input_ids": input_ids, "use_cache": False}

except ImportError:
    # transformers is optional; the standalone NordicFlash class works without it.
    pass