Spaces:
Sleeping
Sleeping
| """ | |
Standalone inference for Nordic Flash 430M (Bifrost family) — pure PyTorch,
depends only on `torch`. Independent re-implementation of the forward pass;
nothing is imported from the training stack.

Hybrid architecture (~430M parameters):
  * 18 layers in a [dynamic_conv, dynamic_conv, gqa] x6 pattern.
  * DynaConv layers: per-token data-dependent causal depthwise conv —
    kernel_proj(x) -> (head_dim 80, K 14) softmax-over-taps weights, each kernel
    shared by 16 channels; out = out_proj(causal_conv(in_proj(x)) * silu(gate_proj(x))).
  * GQA layers: grouped-query attention (16 query / 4 key-value heads,
    head_dim 80, fused QKV), partial rotary (first 25% of head dim).
  * SwiGLU FFN, RMSNorm, parallel residual (one norm), tied embeddings.
| """ | |
| from __future__ import annotations | |
| from dataclasses import dataclass, field | |
| import torch | |
| import torch.nn as nn | |
| import torch.nn.functional as F | |
| def _default_layer_types(): | |
| return ["dynamic_conv", "dynamic_conv", "gqa"] * 6 | |
@dataclass
class FlashConfig:
    """Hyper-parameters of the Nordic Flash model.

    This must be a dataclass: ``layer_types`` uses ``field(default_factory=...)``
    and the config is constructed with keyword arguments elsewhere in the file.
    """

    vocab_size: int = 65008
    hidden_dim: int = 1280
    num_layers: int = 18
    num_attention_heads: int = 16
    num_kv_heads: int = 4
    head_dim: int = 80
    ffn_intermediate: int = 3584
    rope_theta: float = 500000.0
    rope_partial_factor: float = 0.25
    rms_norm_eps: float = 1e-6
    max_position: int = 4096
    dynaconv_kernel: int = 14  # K taps
    dynaconv_head_dim: int = 80  # number of distinct per-token kernels
    dynaconv_num_heads: int = 16  # channels sharing each kernel (head_dim*num_heads = hidden)
    # The lambda defers the lookup of the module-level helper to instantiation
    # time; every instance gets its own list.
    layer_types: list = field(default_factory=lambda: _default_layer_types())

    @property
    def rotary_dim(self) -> int:
        """Number of leading head channels that receive rotary embedding.

        Computed as ``int(head_dim * rope_partial_factor)``, clamped to at
        least 2 and rounded down to an even number. Exposed as a property
        because consumers read it as a plain attribute (``cfg.rotary_dim``).
        """
        rd = max(2, int(self.head_dim * self.rope_partial_factor))
        return rd - (rd % 2)
class RMSNorm(nn.Module):
    """Root-mean-square normalisation over the last axis with a learned scale.

    The arithmetic is carried out in float32 and the result is cast back to
    the dtype of the input.
    """

    def __init__(self, dim: int, eps: float = 1e-6) -> None:
        super().__init__()
        self.eps = eps
        self.weight = nn.Parameter(torch.ones(dim))

    def forward(self, x):
        """Return ``x / sqrt(mean(x**2) + eps) * weight`` in ``x``'s dtype."""
        in_dtype = x.dtype
        h = x.float()
        mean_square = h.pow(2).mean(-1, keepdim=True)
        h = h * torch.rsqrt(mean_square + self.eps)
        scaled = h * self.weight.float()
        return scaled.to(in_dtype)
| def _rotate_half(x): | |
| d = x.shape[-1] | |
| return torch.cat([-x[..., d // 2:], x[..., : d // 2]], dim=-1) | |
class RoPE(nn.Module):
    """Partial rotary position embedding.

    Only the first ``rotary_dim`` channels of each head are rotated; the rest
    pass through unchanged. Positions always start at 0 and run over the
    second-to-last axis of ``q``. ``head_dim`` is accepted but not used.
    """

    def __init__(self, head_dim, rotary_dim, theta):
        super().__init__()
        self.rotary_dim = rotary_dim
        exponents = torch.arange(0, rotary_dim, 2, dtype=torch.float32) / rotary_dim
        # Non-persistent: the buffer is recomputed here and never stored in a state_dict.
        self.register_buffer("inv_freq", 1.0 / (theta ** exponents), persistent=False)

    def forward(self, q, k):
        """Rotate ``q`` and ``k`` and return them as a ``(q, k)`` tuple."""
        n_rot = self.rotary_dim
        positions = torch.arange(q.shape[-2], device=q.device, dtype=torch.float32)
        angles = torch.outer(positions, self.inv_freq.to(q.device))
        angles = torch.cat([angles, angles], dim=-1)
        # Tables are computed in float32, cast to q's dtype, and given two
        # leading singleton axes so they broadcast over the leading dims.
        cos = angles.cos().to(q.dtype)[None, None]
        sin = angles.sin().to(q.dtype)[None, None]

        def _apply(t):
            rotated, passthrough = t[..., :n_rot], t[..., n_rot:]
            half = rotated.shape[-1] // 2
            swapped = torch.cat([-rotated[..., half:], rotated[..., :half]], dim=-1)
            rotated = rotated * cos + swapped * sin
            return torch.cat([rotated, passthrough], -1)

        return _apply(q), _apply(k)
class Attention(nn.Module):
    """Causal grouped-query self-attention with a fused QKV projection.

    Rotary embedding is applied to queries and keys; key/value heads are
    repeated so that each one serves ``nq // nkv`` query heads.
    """

    def __init__(self, cfg: "FlashConfig"):
        super().__init__()
        self.nq = cfg.num_attention_heads
        self.nkv = cfg.num_kv_heads
        self.hd = cfg.head_dim
        self.groups = self.nq // self.nkv
        q_width = self.nq * self.hd
        kv_width = self.nkv * self.hd
        self.qkv_proj = nn.Linear(cfg.hidden_dim, q_width + 2 * kv_width, bias=False)
        self.o_proj = nn.Linear(q_width, cfg.hidden_dim, bias=False)
        # Widths of the q / k / v slices inside the fused projection output.
        self._split = (q_width, kv_width, kv_width)
        self.rope = RoPE(self.hd, cfg.rotary_dim, cfg.rope_theta)

    def forward(self, x):
        """Attend over ``x`` of shape (batch, seq, hidden); the shape is preserved."""
        batch, seq, _ = x.shape
        fused = self.qkv_proj(x)
        q_flat, k_flat, v_flat = fused.split(self._split, dim=-1)

        def _to_heads(t, n_heads):
            # (batch, seq, n_heads * hd) -> (batch, n_heads, seq, hd)
            return t.view(batch, seq, n_heads, self.hd).transpose(1, 2)

        q = _to_heads(q_flat, self.nq)
        k = _to_heads(k_flat, self.nkv)
        v = _to_heads(v_flat, self.nkv)
        q, k = self.rope(q, k)
        if self.groups > 1:
            # Expand the key/value heads to match the number of query heads.
            k, v = (t.repeat_interleave(self.groups, dim=1) for t in (k, v))
        ctx = F.scaled_dot_product_attention(q, k, v, is_causal=True)
        merged = ctx.transpose(1, 2).reshape(batch, seq, self.nq * self.hd)
        return self.o_proj(merged)
class DynaConv(nn.Module):
    """Per-token data-dependent causal depthwise conv with a silu gate.

    kernel_proj(x) -> (head_dim, K) softmax-over-taps weights, each kernel shared
    by num_heads consecutive channels; causal depthwise conv of in_proj(x);
    gated by silu(gate).
    out = out_proj( causal_conv(in_proj(x), softmax_kernel) * silu(gate_proj(x)) ).

    Raises:
        ValueError: if ``dynaconv_head_dim * dynaconv_num_heads != hidden_dim``.
    """

    def __init__(self, cfg: "FlashConfig"):
        super().__init__()
        D = cfg.hidden_dim
        self.K = cfg.dynaconv_kernel
        self.hdc = cfg.dynaconv_head_dim  # distinct kernels
        self.nh = cfg.dynaconv_num_heads  # channels per kernel
        # Explicit check instead of `assert`, which is stripped under `python -O`.
        if self.hdc * self.nh != D:
            raise ValueError(
                f"dynaconv_head_dim * dynaconv_num_heads must equal hidden_dim "
                f"({self.hdc} * {self.nh} != {D})"
            )
        self.in_proj = nn.Linear(D, D, bias=False)
        self.gate_proj = nn.Linear(D, D, bias=False)
        self.out_proj = nn.Linear(D, D, bias=False)
        self.kernel_proj = nn.Linear(D, self.hdc * self.K, bias=False)

    def forward(self, x):
        """Apply the gated dynamic convolution to ``x`` of shape (B, S, D)."""
        B, S, D = x.shape
        K = self.K
        taps = self.kernel_proj(x).view(B, S, self.hdc, K)  # (B,S,head_dim,K)
        taps = F.softmax(taps.float(), dim=-1).to(x.dtype)  # softmax over taps
        bx = self.in_proj(x)
        gate = self.gate_proj(x)
        xp = F.pad(bx, (0, 0, K - 1, 0))  # left-pad seq by K-1
        windows = xp.unfold(1, K, 1)  # (B,S,D,K); windows[...,K-1]=bx[t] (current)
        # Group the channel axis as (kernel, channel-within-kernel) so that the
        # taps broadcast across the channels sharing a kernel (channel c uses
        # kernel c // nh). This avoids materialising a (B,S,D,K) copy of the
        # weights via repeat_interleave.
        windows = windows.unflatten(2, (self.hdc, self.nh))  # (B,S,head_dim,nh,K)
        mixed = (windows * taps.unsqueeze(3)).sum(-1).flatten(2)  # (B,S,D)
        return self.out_proj(mixed * F.silu(gate))
class SwiGLU(nn.Module):
    """Bias-free gated feed-forward block: ``down(silu(gate(x)) * up(x))``."""

    def __init__(self, cfg: "FlashConfig"):
        super().__init__()
        d_model, d_ff = cfg.hidden_dim, cfg.ffn_intermediate
        self.gate_proj = nn.Linear(d_model, d_ff, bias=False)
        self.up_proj = nn.Linear(d_model, d_ff, bias=False)
        self.down_proj = nn.Linear(d_ff, d_model, bias=False)

    def forward(self, x):
        """Project up to ``ffn_intermediate``, gate, and project back down."""
        gated = F.silu(self.gate_proj(x))
        return self.down_proj(gated * self.up_proj(x))
class DecoderBlock(nn.Module):
    """One parallel-residual layer: a single norm feeds both mixer and FFN.

    The mixer is ``Attention`` when ``layer_type == "gqa"`` and ``DynaConv``
    for any other value; it is stored under the attribute name ``attention``
    in both cases.
    """

    def __init__(self, cfg: "FlashConfig", layer_type: str):
        super().__init__()
        self.input_norm = RMSNorm(cfg.hidden_dim, cfg.rms_norm_eps)
        mixer_cls = Attention if layer_type == "gqa" else DynaConv
        self.attention = mixer_cls(cfg)
        self.ffn = SwiGLU(cfg)

    def forward(self, x):
        """Return ``x + mixer(norm(x)) + ffn(norm(x))``."""
        normed = self.input_norm(x)
        mixed = self.attention(normed)
        return x + mixed + self.ffn(normed)
class NordicFlash(nn.Module):
    """Standalone Nordic Flash language model (embedding -> blocks -> norm -> head).

    ``lm_head`` is a separate ``nn.Linear``; ``from_checkpoint`` fills it from
    the embedding matrix when the checkpoint does not carry ``lm_head.weight``.
    """

    def __init__(self, cfg: "FlashConfig"):
        super().__init__()
        self.cfg = cfg
        self.embed_tokens = nn.Embedding(cfg.vocab_size, cfg.hidden_dim)
        self.layers = nn.ModuleList(DecoderBlock(cfg, t) for t in cfg.layer_types)
        self.final_norm = RMSNorm(cfg.hidden_dim, cfg.rms_norm_eps)
        self.lm_head = nn.Linear(cfg.hidden_dim, cfg.vocab_size, bias=False)

    def forward(self, input_ids):
        """Return next-token logits of shape ``input_ids.shape + (vocab_size,)``."""
        x = self.embed_tokens(input_ids)
        for layer in self.layers:
            x = layer(x)
        return self.lm_head(self.final_norm(x))

    @torch.no_grad()
    def translate(self, src_ids, tgt_lang_id, *, bos=1, eos=2, eos_src=65007,
                  max_new=256, vocab_text_limit=65000):
        """Greedy-decode a translation of ``src_ids``.

        The prompt is ``[bos, tgt_lang_id, *src_ids, eos_src]``. Decoding stops
        at ``eos`` or after ``max_new`` generated tokens. Runs without autograd
        so no graph is retained across decoding steps.

        Returns:
            The generated token ids (``eos`` excluded), keeping only ids
            strictly below ``vocab_text_limit``.
        """
        dev = self.embed_tokens.weight.device
        ids = [bos, tgt_lang_id, *src_ids, eos_src]
        out = []
        for _ in range(max_new):
            # No cache: the whole sequence is re-encoded at every step.
            logits = self(torch.tensor([ids], device=dev))
            nxt = int(logits[0, -1].argmax())
            if nxt == eos:
                break
            out.append(nxt)
            ids.append(nxt)
        return [t for t in out if t < vocab_text_limit]

    @classmethod
    def from_checkpoint(cls, path, device="cuda", dtype=torch.bfloat16, cfg=None):
        """Build a model, load weights from ``path`` and return it in eval mode.

        Args:
            path: ``str`` or path-like. Files ending in ``.safetensors`` are read
                with ``safetensors``; anything else goes through ``torch.load``,
                unwrapping a top-level ``"model"`` entry if present.
            device: device the model is moved to after loading.
            dtype: dtype the model is cast to after loading.
            cfg: optional ``FlashConfig``; a default one is used when falsy.

        Raises:
            RuntimeError: if keys other than ``inv_freq`` buffers are missing
                from, or unexpected in, the checkpoint.
        """
        cfg = cfg or FlashConfig()
        model = cls(cfg)
        path = str(path)  # accept pathlib.Path as well as str
        if path.endswith(".safetensors"):
            from safetensors.torch import load_file
            sd = load_file(path)
        else:
            # SECURITY: weights_only=False unpickles arbitrary Python objects.
            # Only load checkpoints from a trusted source.
            sd = torch.load(path, map_location="cpu", weights_only=False)
            sd = sd.get("model", sd)
        # Strip torch.compile wrappers, both as a leading prefix
        # ("_orig_mod.x") and in the middle of a key ("a._orig_mod.b").
        sd = {k.replace("_orig_mod.", ""): v for k, v in sd.items()}
        if "lm_head.weight" not in sd and "embed_tokens.weight" in sd:
            sd["lm_head.weight"] = sd["embed_tokens.weight"]  # tied
        missing, unexpected = model.load_state_dict(sd, strict=False)
        # Rotary inv_freq buffers are recomputed at construction; ignore them.
        missing = [m for m in missing if "inv_freq" not in m]
        unexpected = [u for u in unexpected if "inv_freq" not in u]
        if missing or unexpected:
            raise RuntimeError(f"state_dict mismatch:\n missing={missing}\n unexpected={unexpected}")
        return model.to(device=device, dtype=dtype).eval()
| # --------------------------------------------------------------------------- # | |
| # Optional HuggingFace wrapper — AutoModelForCausalLM.from_pretrained + # | |
| # model.generate() (trust_remote_code=True). Cache-free recompute for # | |
| # version-robustness; NordicFlash.translate() is the fast standalone path. # | |
| # --------------------------------------------------------------------------- # | |
try:
    # GenerationMixin is imported explicitly: from transformers 4.50 on,
    # PreTrainedModel no longer provides generate() by itself.
    from transformers import GenerationMixin, PretrainedConfig, PreTrainedModel
    from transformers.modeling_outputs import CausalLMOutputWithPast
except ImportError:
    # transformers is an optional dependency; without it only the standalone
    # NordicFlash path is available.
    pass
else:
    class NordicFlashConfig(PretrainedConfig):
        """HuggingFace config mirroring the fields of ``FlashConfig``.

        Also sets the conventional alias attributes ``num_hidden_layers``,
        ``hidden_size``, ``num_key_value_heads`` and ``max_position_embeddings``.
        Defaults ``tie_word_embeddings`` to True and ``use_cache`` to False
        unless the caller passes them.
        """

        model_type = "nordic_flash"

        def __init__(self, vocab_size=65008, hidden_dim=1280, num_layers=18,
                     num_attention_heads=16, num_kv_heads=4, head_dim=80,
                     ffn_intermediate=3584, rope_theta=500000.0,
                     rope_partial_factor=0.25, rms_norm_eps=1e-6, max_position=4096,
                     dynaconv_kernel=14, dynaconv_head_dim=80, dynaconv_num_heads=16,
                     layer_types=None, **kwargs):
            self.vocab_size = vocab_size
            self.hidden_dim = hidden_dim
            self.num_layers = num_layers
            self.num_attention_heads = num_attention_heads
            self.num_kv_heads = num_kv_heads
            self.head_dim = head_dim
            self.ffn_intermediate = ffn_intermediate
            self.rope_theta = rope_theta
            self.rope_partial_factor = rope_partial_factor
            self.rms_norm_eps = rms_norm_eps
            self.max_position = max_position
            self.dynaconv_kernel = dynaconv_kernel
            self.dynaconv_head_dim = dynaconv_head_dim
            self.dynaconv_num_heads = dynaconv_num_heads
            self.layer_types = layer_types or _default_layer_types()
            # Aliases under the attribute names HuggingFace utilities look up.
            self.num_hidden_layers = num_layers
            self.hidden_size = hidden_dim
            self.num_key_value_heads = num_kv_heads
            self.max_position_embeddings = max_position
            kwargs.setdefault("tie_word_embeddings", True)
            kwargs.setdefault("use_cache", False)
            super().__init__(**kwargs)

        def to_flash(self):
            """Return an equivalent ``FlashConfig`` (``layer_types`` is copied)."""
            return FlashConfig(
                vocab_size=self.vocab_size, hidden_dim=self.hidden_dim,
                num_layers=self.num_layers, num_attention_heads=self.num_attention_heads,
                num_kv_heads=self.num_kv_heads, head_dim=self.head_dim,
                ffn_intermediate=self.ffn_intermediate, rope_theta=self.rope_theta,
                rope_partial_factor=self.rope_partial_factor, rms_norm_eps=self.rms_norm_eps,
                max_position=self.max_position, dynaconv_kernel=self.dynaconv_kernel,
                dynaconv_head_dim=self.dynaconv_head_dim, dynaconv_num_heads=self.dynaconv_num_heads,
                layer_types=list(self.layer_types))

    class NordicFlashForCausalLM(PreTrainedModel, GenerationMixin):
        """HuggingFace wrapper with the same module layout as ``NordicFlash``.

        Generation is cache-free: every step re-runs the full sequence.
        ``attention_mask`` and ``past_key_values`` are accepted for API
        compatibility but are not used by the forward pass.
        """

        config_class = NordicFlashConfig
        _supports_cache_class = False
        _no_split_modules = ["DecoderBlock"]

        def __init__(self, config: "NordicFlashConfig"):
            super().__init__(config)
            cfg = config.to_flash()
            self.embed_tokens = nn.Embedding(cfg.vocab_size, cfg.hidden_dim)
            self.layers = nn.ModuleList(DecoderBlock(cfg, t) for t in cfg.layer_types)
            self.final_norm = RMSNorm(cfg.hidden_dim, cfg.rms_norm_eps)
            self.lm_head = nn.Linear(cfg.hidden_dim, cfg.vocab_size, bias=False)
            self.post_init()

        def get_input_embeddings(self):
            return self.embed_tokens

        def set_input_embeddings(self, v):
            self.embed_tokens = v

        def get_output_embeddings(self):
            return self.lm_head

        def set_output_embeddings(self, v):
            self.lm_head = v

        def forward(self, input_ids=None, attention_mask=None, past_key_values=None,
                    use_cache=None, labels=None, return_dict=True, **kwargs):
            """Run the model and return a ``CausalLMOutputWithPast``.

            When ``labels`` is given, ``loss`` is the next-token cross-entropy:
            logits at position ``t`` are scored against ``labels[t + 1]`` and
            label value ``-100`` is ignored. Otherwise ``loss`` is ``None``.
            ``return_dict`` is accepted but a model-output object is always
            returned.
            """
            x = self.embed_tokens(input_ids)
            for layer in self.layers:
                x = layer(x)
            logits = self.lm_head(self.final_norm(x))
            loss = None
            if labels is not None:
                # The loss is computed in float32 regardless of the model dtype.
                shift_logits = logits[:, :-1].float()
                shift_labels = labels[:, 1:].to(shift_logits.device)
                loss = F.cross_entropy(
                    shift_logits.reshape(-1, shift_logits.size(-1)),
                    shift_labels.reshape(-1),
                    ignore_index=-100,
                )
            return CausalLMOutputWithPast(loss=loss, logits=logits)

        def prepare_inputs_for_generation(self, input_ids, **kwargs):
            """Feed the full sequence at every step and keep caching disabled."""
            return {"input_ids": input_ids, "use_cache": False}