Commit 0ef192a (verified) by joelhenwang · 1 parent: cb28e4d

OdinNext-138M-Base: EMA weights (101.6B-token dolmino base)

README.md ADDED
@@ -0,0 +1,203 @@
+ ---
+ license: apache-2.0
+ language:
+ - en
+ library_name: transformers
+ pipeline_tag: text-generation
+ tags:
+ - odinnext
+ - hgrn2
+ - linear-attention
+ - recurrent
+ - causal-lm
+ - custom_code
+ - base-model
+ - fp16
+ - amd
+ - rocm
+ - arxiv:2404.07904
+ - arxiv:2605.06546
+ - arxiv:2407.12665
+ - arxiv:2506.14202
+ ---
+
+ # OdinNext-138M-Base
+
+ **OdinNext** is a 138M-parameter causal language model that replaces softmax
+ self-attention with an **HGRN2-style gated linear recurrence**. This repository
+ is the **base pretrained model**, trained from scratch on ~101.6B tokens of
+ curated data (the Dolmino mix) on two AMD Strix Halo (gfx1151) machines.
+
+ This is a **base model**: it completes and continues text. It is **not** an
+ instruction-tuned or chat model: no SFT, DPO, RLHF, or chat template. Those
+ stages are in progress and will ship as a separate `*-Instruct` repository.
+
+ - **Repo:** `joelhenwang/OdinNext-138M-Base`
+ - **`main`:** EMA-shadowed weights (decay 0.999), recommended.
+ - **`live`:** raw training weights at the same step.
+ - **Context window:** 2,048 tokens in the released inference code.
+ - **License:** Apache-2.0.
+
+ > This model uses custom Transformers code. Loading with `trust_remote_code=True`
+ > executes Python from this repo. Review the files or pin a commit before trusting it.
+
+ ## At a glance
+
+ | Item | Value |
+ |---|---:|
+ | Unique parameters (tied embedding counted once) | **138,449,696** |
+ | Non-embedding parameters | **113,283,872** |
+ | Layers | 16 |
+ | Hidden size | 768 |
+ | Heads | 6 |
+ | Head state dims | 128 × 128 per head |
+ | FFN inner size | 2,048 |
+ | Vocabulary | 32,768 custom BPE tokens |
+ | Max sequence length | 2,048 |
+ | Checkpoint dtype | fp16 |
+ | Architecture | HGRN2 recurrence + alternating RoPE + SwiGLU² FFN + ZCRMSNorm |
+ | Cache type | Fixed-size recurrent state, not a growing KV cache |
+
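The two parameter counts in the table can be reproduced from the layer shapes in `modeling_odinnext.py`. The sketch below is a back-of-the-envelope check, assuming bias-free linears, a tied embedding, three norm weight vectors per block (pre-norm, FFN norm, and the post-recurrence `g_norm`), and two scalar residual gates per block. `odinnext_param_count` is a hypothetical helper, not part of this repo.

```python
def odinnext_param_count(vocab=32768, d_model=768, n_layers=16, ffn_inner=2048):
    """Count unique parameters, assuming bias-free linears and tied embeddings."""
    embedding = vocab * d_model                  # shared with the LM head
    recurrence = 4 * d_model * d_model           # q, f, i, o projections
    ffn = d_model * 2 * ffn_inner + ffn_inner * d_model  # fused gate/up + down
    norms = 3 * d_model                          # pre-norm, FFN norm, g_norm
    gates = 2                                    # scalar gate_attn and gate_ffn
    per_layer = recurrence + ffn + norms + gates
    non_embedding = n_layers * per_layer + d_model  # plus the final norm
    return embedding + non_embedding, non_embedding


total, non_embedding = odinnext_param_count()
print(f"{total:,} total, {non_embedding:,} non-embedding")
```

The head count does not appear in the formula because the default per-head width is `d_model // n_heads`, so every projection stays `d_model × d_model`.
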
+ ## Architecture
+
+ Decoder-only causal LM built from 16 pre-norm blocks that share one structure:
+
+ ```text
+ x = x + sigmoid(gate_attn) * HGRN2(ZCRMSNorm(x))
+ x = x + sigmoid(gate_ffn) * SwiGLU²(ZCRMSNorm(x))
+ ```
+
+ The HGRN2 recurrent state updates per token as:
+
+ ```text
+ S_t = diag(exp(g_t)) S_{t-1} + k_t ⊗ v_t
+ o_t = q_t S_t
+ ```
+
+ Here `g_t = logsigmoid(f_t)` is the log forget gate and `k_t = sigmoid(-f_t) =
+ 1 - exp(g_t)` is the tied input gate, both computed from the same projection
+ `f_t` (before RoPE). The per-layer state is shaped
+ `[B, n_heads, head_f_dim, head_i_dim]` = `[B, 6, 128, 128]`. This state is
+ **constant in size with respect to context length**, giving O(1)-per-token
+ decoding rather than a growing KV cache.
+
+ **Hybrid RoPE:** even layers (0, 2, …, 14) apply RoPE to q/k (θ = 100,000);
+ odd layers are position-free. Tied embedding / LM head. No linear biases.
+
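A minimal pure-Python sketch of the per-head update above, using toy dimensions and made-up values. It is only meant to show that the state keeps its `K × V` shape while tokens stream through; `hgrn2_step` is a hypothetical name, and the real model uses the `fla` kernels or the PyTorch fallback shipped in this commit.

```python
import math


def hgrn2_step(S, q, k, v, g):
    """One recurrent step for a single head.

    S: K x V state (list of rows); q, k, g: length K; v: length V.
    g is a log-decay (<= 0), so exp(g) lies in (0, 1].
    """
    K, V = len(k), len(v)
    # S_t = diag(exp(g_t)) S_{t-1} + k_t (outer) v_t
    new_S = [
        [math.exp(g[i]) * S[i][j] + k[i] * v[j] for j in range(V)]
        for i in range(K)
    ]
    # o_t = q_t S_t
    o = [sum(q[i] * new_S[i][j] for i in range(K)) for j in range(V)]
    return new_S, o


K, V = 2, 3
S = [[0.0] * V for _ in range(K)]
tokens = [
    # (q, k, v, g) per token; arbitrary toy numbers
    ([1.0, 0.0], [0.5, 0.5], [1.0, 2.0, 3.0], [-0.1, -0.2]),
    ([0.0, 1.0], [0.2, 0.8], [0.0, 1.0, 0.0], [-0.3, -0.4]),
]
outputs = []
for q, k, v, g in tokens:
    S, o = hgrn2_step(S, q, k, v, g)
    outputs.append(o)

print(len(S), len(S[0]))  # state shape is unchanged by the number of tokens
```
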
+ ## Memory: recurrent state vs Transformer KV cache
+
+ For batch size 1 in fp16 the recurrent state is constant:
+
+ ```text
+ layers × heads × head_f_dim × head_i_dim × bytes
+ = 16 × 6 × 128 × 128 × 2 = 3,145,728 bytes ≈ 3.0 MiB
+ ```
+
+ independent of generated length (the pure-PyTorch fallback promotes the scan
+ state to fp32, ≈ 6.0 MiB). A same-depth, same-width fp16 Transformer KV cache
+ (2 × 16 layers × 768 values per token) would grow linearly: ≈ 48 MiB at 1K
+ tokens, ≈ 768 MiB at 16K. This is a cache-state comparison only, not a claim
+ about total memory or usable context.
+
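The arithmetic above can be checked with a few lines of Python. This is a sketch under stated assumptions: the comparison Transformer is hypothetical, with the same depth (16) and width (768) and a full multi-head K/V cache (no grouped-query sharing); both function names are illustrative.

```python
MIB = 1024 ** 2


def recurrent_state_bytes(layers=16, heads=6, head_f_dim=128,
                          head_i_dim=128, bytes_per_elem=2):
    """Fixed HGRN2 state size for batch size 1."""
    return layers * heads * head_f_dim * head_i_dim * bytes_per_elem


def kv_cache_bytes(tokens, layers=16, d_model=768, bytes_per_elem=2):
    """K and V each store d_model values per layer per token."""
    return 2 * layers * d_model * bytes_per_elem * tokens


for tokens in (1024, 2048, 16384):
    print(tokens,
          recurrent_state_bytes() / MIB,
          kv_cache_bytes(tokens) / MIB)
```
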
+ ## Training snapshot
+
+ | Field | Value |
+ |---|---|
+ | Data | Dolmino mix (~101.6B tokens, odin-32k tokenizer) |
+ | Hardware | 2× AMD Strix Halo / gfx1151, ROCm 7.13 |
+ | Interconnect | Thunderbolt 4, DDP over gloo |
+ | Precision | fp16 + GradScaler |
+ | Optimizers | NorMuon (2D tensors) + AdamW (1D / embeddings) |
+ | LR | peak 8e-4, warmup, cosine decay |
+ | Stabilization | z-loss 1e-4, attention soft-cap 50, EMA decay 0.999 |
+ | Curriculum | Phase 1: Token-Superposition Training (bag-size 4) + DiffusionBlocks (block-wise) for ~24K steps; Phase 2: standard end-to-end autoregressive recovery |
+ | Released weights | `main` = `ema_state_dict`; `live` = raw online weights |
+
+ The two-phase curriculum trains most of the budget under a block-wise
+ DiffusionBlocks + token-superposition objective for throughput, then recovers
+ ordinary left-to-right generation with a standard end-to-end phase. The
+ released weights are from the end-to-end recovery phase and produce coherent
+ continuations.
+
+ ## What this model is good for
+
+ - Text continuation and completion in English.
+ - Research on compact recurrent / linear-attention LMs and fixed-state decoding.
+ - A base for instruction tuning, alignment, and context extension.
+
+ Do **not** use it for chat / instruction following (not tuned yet),
+ safety-sensitive generation, or benchmark claims without running your own
+ evaluation.
+
+ ## Usage
+
+ ```bash
+ pip install "transformers>=4.46" torch safetensors
+ ```
+
+ ```python
+ import torch
+ from transformers import AutoModelForCausalLM, AutoTokenizer
+
+ repo = "joelhenwang/OdinNext-138M-Base"
+ revision = "main"  # EMA weights; pin a commit for reproducibility
+
+ device = "cuda" if torch.cuda.is_available() else "cpu"
+ dtype = torch.float16 if device == "cuda" else torch.float32
+
+ tok = AutoTokenizer.from_pretrained(repo, revision=revision)
+ model = AutoModelForCausalLM.from_pretrained(
+     repo, revision=revision, trust_remote_code=True, torch_dtype=dtype,
+ ).to(device).eval()
+
+ prompt = "The discovery of penicillin"
+ inputs = tok(prompt, return_tensors="pt").to(device)
+ remaining = model.config.max_position_embeddings - inputs.input_ids.shape[1]
+ with torch.inference_mode():
+     out = model.generate(
+         **inputs,
+         max_new_tokens=max(0, min(100, remaining)),
+         do_sample=True, temperature=0.8, top_p=0.95, repetition_penalty=1.1,
+         pad_token_id=tok.pad_token_id, use_cache=True,
+     )
+ print(tok.decode(out[0], skip_special_tokens=True))
+ ```
+
+ ### Batching guidance
+
+ The recurrent scan does not apply an attention mask. For correct batched
+ generation: avoid left padding, prefer same-length prompts, and verify batched
+ output against single-sample output before relying on it. Single-prompt
+ generation is the safest path.
+
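The guidance above follows from the recurrence itself: a pad token still produces gate and value activations, so it is folded into the state like any other token. The sketch below is a toy scalar stand-in for the scan with made-up numbers, not the model's kernel; `scan`, `prompt`, and `pad` are illustrative names.

```python
import math


def scan(tokens, state=0.0):
    """Scalar stand-in for the recurrence: s = exp(g) * s + k * v."""
    for g, k, v in tokens:
        state = math.exp(g) * state + k * v
    return state


prompt = [(-0.1, 0.5, 1.0), (-0.2, 0.3, 2.0)]
pad = (-0.1, 0.4, 0.7)  # an unmasked pad token still yields some g, k, v

clean = scan(prompt)
left_padded = scan([pad, pad] + prompt)
print(clean, left_padded)  # the two final states differ
```

With left padding the pad contribution is decayed but never removed, which is why comparing batched output against single-sample output is worth doing before trusting a batch.
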
+ ## Limitations
+
+ - **Base model only:** no instruction tuning, alignment, or chat template.
+ - **No safety training:** outputs can be biased, false, or incoherent.
+ - **Hard 2,048-token cap:** recurrent state is constant, but the released RoPE
+   cache limits cumulative positions to 2,048.
+ - **`attention_mask` ignored** in the backbone; padding affects recurrent state.
+ - **English-focused;** multilingual / code ability is uncharacterized.
+ - **Formal benchmarks not published in this card yet.** Treat quality as
+   preliminary and run your own evaluation.
+
+ ## Revisions
+
+ - `main`: EMA-shadowed weights (decay 0.999), recommended for evaluation.
+ - `live`: raw training weights at the same step.
+
+ Pin a commit hash rather than a moving branch for reproducible experiments.
+
+ ## Citation
+
+ ```bibtex
+ @misc{odinnext_138m_base_2026,
+   title = {OdinNext-138M-Base},
+   author = {Wang, Joel},
+   year = {2026},
+   howpublished = {\url{https://huggingface.co/joelhenwang/OdinNext-138M-Base}},
+   note = {138M HGRN2 recurrent language-model base checkpoint}
+ }
+ ```
+
+ ## References
+
+ - Zhen Qin et al. **HGRN2: Gated Linear RNNs with State Expansion.** arXiv:2404.07904.
+ - Bowen Peng et al. **Efficient Pre-Training with Token Superposition.** arXiv:2605.06546.
+ - Chenze Shao et al. **Patch-Level Training for Large Language Models.** arXiv:2407.12665.
+ - Makoto Shing et al. **DiffusionBlocks: Block-wise Neural Network Training via Diffusion Interpretation.** arXiv:2506.14202.
_hgrn2_fallback.py ADDED
@@ -0,0 +1,101 @@
+ # coding=utf-8
+ # Copyright 2026 The OdinNext authors.
+ # Licensed under the Apache License, Version 2.0.
+ """Pure-PyTorch HGRN2 recurrence: slow fallback when flash-linear-attention
+ (`fla`) is unavailable.
+
+ The `fla` library provides Triton/CUDA kernels for `chunk_gla` (chunk-wise
+ parallel scan over T) and `fused_recurrent_gla` (token-by-token serial scan).
+ On platforms without those kernels (CPU, non-CUDA/non-ROCm GPUs) we provide
+ a reference implementation here.
+
+ Speed: ~10-30x slower than `fla` at training shapes; comparable for
+ single-token decode (since both are serial). Numerical match: bitwise on
+ fp32, within fp16 noise on fp16.
+
+ The recurrence (per head):
+     S_t = diag(exp(g_t)) @ S_{t-1} + k_t.unsqueeze(-1) @ v_t.unsqueeze(-2)
+     o_t = q_t @ S_t
+
+ Shapes (matching `fla.ops.gla.chunk_gla`):
+     q: [B, T, H, K] (K = head_f_dim, e.g. 128)
+     k: [B, T, H, K]
+     g: [B, T, H, K] (already in log-space, expected to be <= 0)
+     v: [B, T, H, V] (V = head_i_dim, e.g. 128)
+     -> o: [B, T, H, V]
+     final_state: [B, H, K, V] if output_final_state else None
+ """
+
+ from typing import Optional, Tuple
+
+ import torch
+
+
+ def chunk_gla(
+     q: torch.Tensor,
+     k: torch.Tensor,
+     v: torch.Tensor,
+     g: torch.Tensor,
+     initial_state: Optional[torch.Tensor] = None,
+     output_final_state: bool = False,
+     **_unused,
+ ) -> Tuple[torch.Tensor, Optional[torch.Tensor]]:
+     """Pure-PyTorch chunk_gla replacement.
+
+     Implements a serial (token-by-token) scan. We promote internals to fp32
+     to keep the cumulative product of decays numerically sane over long T.
+     """
+     B, T, H, K = q.shape
+     V = v.shape[-1]
+     device = q.device
+     in_dtype = q.dtype
+
+     # Promote scan internals to fp32 for stability (matches fla behavior).
+     q32 = q.float()
+     k32 = k.float()
+     v32 = v.float()
+     g32 = g.float()
+
+     if initial_state is None:
+         S = torch.zeros(B, H, K, V, device=device, dtype=torch.float32)
+     else:
+         S = initial_state.to(dtype=torch.float32)
+
+     out = torch.empty(B, T, H, V, device=device, dtype=torch.float32)
+
+     # Serial scan. exp(g_t) decays state element-wise along K.
+     # k_t outer v_t -> [B, H, K, V] additive update.
+     for t in range(T):
+         decay = g32[:, t].exp().unsqueeze(-1)  # [B, H, K, 1]
+         S = decay * S + k32[:, t].unsqueeze(-1) * v32[:, t].unsqueeze(-2)
+         # o_t = q_t (1xK) @ S (KxV) per head
+         out[:, t] = (q32[:, t].unsqueeze(-2) @ S).squeeze(-2)  # [B, H, V]
+
+     out = out.to(in_dtype)
+     if output_final_state:
+         return out, S
+     return out, None
+
+
+ def fused_recurrent_gla(
+     q: torch.Tensor,
+     k: torch.Tensor,
+     v: torch.Tensor,
+     gk: torch.Tensor,
+     initial_state: Optional[torch.Tensor] = None,
+     output_final_state: bool = True,
+     **_unused,
+ ) -> Tuple[torch.Tensor, Optional[torch.Tensor]]:
+     """Pure-PyTorch single-token (or short-T) recurrence.
+
+     `fla.ops.gla.fused_recurrent_gla` is what OdinNext.generate uses for
+     O(1) per-token decode. The signature matches: `gk` = log-decay (instead
+     of `g`). We reuse `chunk_gla` internals: they are mathematically the
+     same scan, just packaged with different defaults for kernel selection
+     in fla.
+     """
+     return chunk_gla(
+         q=q, k=k, v=v, g=gk,
+         initial_state=initial_state,
+         output_final_state=output_final_state,
+     )
config.json ADDED
@@ -0,0 +1,32 @@
+ {
+   "model_type": "odinnext",
+   "architectures": [
+     "OdinNextForCausalLM"
+   ],
+   "auto_map": {
+     "AutoConfig": "configuration_odinnext.OdinNextConfig",
+     "AutoModelForCausalLM": "modeling_odinnext.OdinNextForCausalLM"
+   },
+   "vocab_size": 32768,
+   "d_model": 768,
+   "n_layers": 16,
+   "n_heads": 6,
+   "ffn_inner": 2048,
+   "max_seq_len": 2048,
+   "rope_theta": 100000.0,
+   "tie_embeddings": true,
+   "tie_word_embeddings": true,
+   "use_cache": true,
+   "torch_dtype": "float16",
+   "bos_token_id": 0,
+   "eos_token_id": 0,
+   "pad_token_id": 1,
+   "hidden_size": 768,
+   "num_hidden_layers": 16,
+   "num_attention_heads": 6,
+   "intermediate_size": 2048,
+   "max_position_embeddings": 2048,
+   "_training_step": 5000,
+   "_total_tokens": 5243928576,
+   "_weights_source": "ema_state_dict"
+ }
configuration_odinnext.py ADDED
@@ -0,0 +1,120 @@
+ # coding=utf-8
+ # Copyright 2026 The OdinNext authors.
+ # Licensed under the Apache License, Version 2.0.
+ """OdinNext model configuration."""
+
+ from transformers import PretrainedConfig
+
+
+ class OdinNextConfig(PretrainedConfig):
+     r"""Configuration class for [`OdinNextForCausalLM`].
+
+     OdinNext is a 138M-parameter HGRN2+RoPE hybrid causal language model.
+     The architecture interleaves two layer types:
+       * Even layers (0, 2, 4, ..., 14): HGRN2 gated linear recurrence with
+         rotary position embeddings (RoPE) on q/k.
+       * Odd layers (1, 3, 5, ..., 15): the same HGRN2 recurrence WITHOUT
+         positional encoding (position-free, generalizes to any length).
+
+     HGRN2 gives O(T) training and O(1) per-token inference: the per-layer
+     recurrent state has a fixed size independent of context length.
+
+     Args:
+         vocab_size (`int`, *optional*, defaults to 32768):
+             Vocabulary size of the OdinNext model.
+         d_model (`int`, *optional*, defaults to 768):
+             Hidden size of the residual stream.
+         n_layers (`int`, *optional*, defaults to 16):
+             Number of transformer-style blocks.
+         n_heads (`int`, *optional*, defaults to 6):
+             Number of recurrence heads. Per-head expand dim is
+             `d_model // n_heads = 128` for the default configuration.
+         ffn_inner (`int`, *optional*, defaults to 2048):
+             SwiGLU2 inner dimension.
+         max_seq_len (`int`, *optional*, defaults to 2048):
+             Maximum sequence length the RoPE cache covers. Generation past
+             this position raises an error (to extend, increase `max_seq_len`
+             and re-instantiate the model).
+         rope_theta (`float`, *optional*, defaults to 100000.0):
+             RoPE base frequency. Even layers only.
+         tie_embeddings (`bool`, *optional*, defaults to `True`):
+             Tie input embedding matrix and output LM-head weight.
+         initializer_range (`float`, *optional*, defaults to 0.02):
+             Unused at inference; recorded for parity with HF conventions.
+         bos_token_id (`int`, *optional*, defaults to 0):
+             Same as eos for this tokenizer (`<|endoftext|>`).
+         eos_token_id (`int`, *optional*, defaults to 0):
+             `<|endoftext|>` token id.
+         pad_token_id (`int`, *optional*, defaults to 1):
+             `<|pad|>` token id in the odin-32k tokenizer.
+         use_cache (`bool`, *optional*, defaults to `True`):
+             Whether to return per-layer recurrent states from `forward()`,
+             and whether `generate()` should consume them. The "cache" here
+             is a list of fixed-size HGRN2 states, NOT a growing KV cache.
+
+     Example:
+
+     ```python
+     >>> from transformers import AutoConfig
+     >>> config = AutoConfig.from_pretrained(
+     ...     "joelhenwang/OdinNext-138M-Early-Checkpoint",
+     ...     trust_remote_code=True,
+     ... )
+     >>> config.d_model
+     768
+     ```
+     """
+
+     model_type = "odinnext"
+     keys_to_ignore_at_inference = ["past_key_values"]
+
+     def __init__(
+         self,
+         vocab_size: int = 32768,
+         d_model: int = 768,
+         n_layers: int = 16,
+         n_heads: int = 6,
+         ffn_inner: int = 2048,
+         max_seq_len: int = 2048,
+         rope_theta: float = 100000.0,
+         tie_embeddings: bool = True,
+         initializer_range: float = 0.02,
+         bos_token_id: int = 0,
+         eos_token_id: int = 0,
+         pad_token_id: int = 1,
+         use_cache: bool = True,
+         **kwargs,
+     ):
+         self.vocab_size = vocab_size
+         self.d_model = d_model
+         self.n_layers = n_layers
+         self.n_heads = n_heads
+         self.ffn_inner = ffn_inner
+         self.max_seq_len = max_seq_len
+         self.rope_theta = rope_theta
+         self.tie_embeddings = tie_embeddings
+         self.initializer_range = initializer_range
+         self.use_cache = use_cache
+
+         # Common HF aliases: many libraries (lm-eval-harness, vLLM compat
+         # layers, etc.) reach for these names. Provide them as direct
+         # passthroughs so external tooling has a chance of working.
+         self.hidden_size = d_model
+         self.num_hidden_layers = n_layers
+         self.num_attention_heads = n_heads
+         self.intermediate_size = ffn_inner
+         self.max_position_embeddings = max_seq_len
+
+         # Strip keys we are about to pass explicitly so they don't double up
+         # via **kwargs (config.json may carry duplicates).
+         kwargs.pop("tie_word_embeddings", None)
+         kwargs.pop("bos_token_id", None)
+         kwargs.pop("eos_token_id", None)
+         kwargs.pop("pad_token_id", None)
+
+         super().__init__(
+             bos_token_id=bos_token_id,
+             eos_token_id=eos_token_id,
+             pad_token_id=pad_token_id,
+             tie_word_embeddings=tie_embeddings,
+             **kwargs,
+         )
generation_config.json ADDED
@@ -0,0 +1,11 @@
+ {
+   "bos_token_id": 0,
+   "eos_token_id": 0,
+   "pad_token_id": 1,
+   "max_new_tokens": 128,
+   "do_sample": true,
+   "temperature": 0.8,
+   "top_p": 0.95,
+   "repetition_penalty": 1.1,
+   "use_cache": true
+ }
model.safetensors ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:bfc1fdd190627224dedcaa0f8894b7efdcb4e8c2207fd86de2a649c1e1fa7f56
+ size 276917608
modeling_odinnext.py ADDED
@@ -0,0 +1,617 @@
+ # coding=utf-8
+ # Copyright 2026 The OdinNext authors.
+ # Licensed under the Apache License, Version 2.0.
+ """OdinNext: 138M HGRN2+RoPE hybrid causal language model.
+
+ This is a self-contained HuggingFace `trust_remote_code=True` port of the
+ production OdinNext model used to train the 6.84B-token early checkpoint.
+ The training-time machinery (DiffusionBlocks, TST, gate-absorption,
+ torch.compile zone helpers) is dropped; only the inference path remains.
+
+ Architecture summary:
+   * 16 layers, d=768, 6 heads, ffn=2048, vocab=32768.
+   * Even layers (0,2,...,14) get RoPE on q/k.
+   * Odd layers (1,3,...,15) are position-free recurrent.
+   * SwiGLU2 FFN: silu(gate)^2 * up.
+   * ZCRMSNorm normalization, gated residuals (frozen at training time).
+   * Tied input/output embeddings.
+   * HGRN2 recurrence: O(T) train, O(1) per-token decode.
+
+ Hardware notes:
+   * Uses `flash-linear-attention` (`fla`) Triton kernels when available.
+     Falls back to a pure-PyTorch implementation (~10-30x slower) otherwise,
+     so the model loads on any backend including CPU.
+   * Trained in fp16 on AMD Strix Halo (gfx1151, RDNA 3.5, ROCm 7.13).
+     fp16 is the recommended inference dtype. bf16 was never validated on
+     this checkpoint.
+ """
+
+ from __future__ import annotations
+
+ import math
+ from typing import List, Optional, Tuple, Union
+
+ import torch
+ import torch.nn as nn
+ import torch.nn.functional as F
+
+ from transformers import PreTrainedModel
+ from transformers.modeling_outputs import CausalLMOutputWithPast
+
+ from .configuration_odinnext import OdinNextConfig
+
+ # ---------------------------------------------------------------------------
+ # HGRN2 kernel: prefer flash-linear-attention, fall back to pure PyTorch
+ # ---------------------------------------------------------------------------
+
+ try:
+     from fla.ops.gla import chunk_gla as _chunk_gla
+     from fla.ops.gla import fused_recurrent_gla as _fused_recurrent_gla
+
+     # `fla.ops.gla.chunk.ChunkGLAFunction` is decorated with
+     # @torch.compiler.disable. Marking it allow_in_graph lets Dynamo treat
+     # it as an opaque leaf op, preventing graph breaks if the user does
+     # `torch.compile(model)`. Best-effort, ignored if internals shift.
+     try:
+         from fla.ops.gla.chunk import ChunkGLAFunction
+         torch.compiler.allow_in_graph(ChunkGLAFunction)
+     except Exception:
+         pass
+
+     _HAS_FLA = True
+ except Exception:  # ImportError, missing Triton, no CUDA/ROCm, ...
+     from ._hgrn2_fallback import chunk_gla as _chunk_gla
+     from ._hgrn2_fallback import fused_recurrent_gla as _fused_recurrent_gla
+     _HAS_FLA = False
+
+
+ # ---------------------------------------------------------------------------
+ # Building blocks
+ # ---------------------------------------------------------------------------
+
+
+ class ZCRMSNorm(nn.Module):
+     """Zero-Centered RMSNorm.
+
+     Stored weight is initialized to 1.0; F.rms_norm sees a leaf parameter
+     directly. Mathematically equivalent to RMSNorm with `gamma = weight - 1`.
+     """
+
+     def __init__(self, dim: int, eps: float = 1e-6):
+         super().__init__()
+         self.eps = eps
+         self.weight = nn.Parameter(torch.ones(dim))
+         self._normalized_shape = (dim,)
+
+     def forward(self, x: torch.Tensor) -> torch.Tensor:
+         return F.rms_norm(x, self._normalized_shape, self.weight, self.eps)
+
+
+ class SwiGLU2(nn.Module):
+     """SwiGLU squared FFN: silu(gate)^2 * up -> down."""
+
+     def __init__(self, d_model: int, ffn_inner: int):
+         super().__init__()
+         self.w_gate_up = nn.Linear(d_model, 2 * ffn_inner, bias=False)
+         self.w_down = nn.Linear(ffn_inner, d_model, bias=False)
+
+     def forward(self, x: torch.Tensor) -> torch.Tensor:
+         gate, up = self.w_gate_up(x).chunk(2, dim=-1)
+         return self.w_down(F.silu(gate).square() * up)
+
+
+ def _apply_rope(
+     x: torch.Tensor, cos: torch.Tensor, sin: torch.Tensor
+ ) -> torch.Tensor:
+     """Apply RoPE to x[B,T,H,D] using real arithmetic.
+
+     cos/sin: [1, T, 1, D/2] pre-broadcast.
+     """
+     x_even = x[..., 0::2]
+     x_odd = x[..., 1::2]
+     out_even = x_even * cos - x_odd * sin
+     out_odd = x_even * sin + x_odd * cos
+     return torch.stack([out_even, out_odd], dim=-1).flatten(-2)
+
+
+ class OdinNextAttention(nn.Module):
+     """HGRN2 attention with optional RoPE on q/k."""
+
+     def __init__(
+         self,
+         d_model: int = 768,
+         n_heads: int = 6,
+         expand_ratio: Optional[int] = None,
+         use_rope: bool = True,
+     ):
+         super().__init__()
+         self.d_model = d_model
+         self.n_heads = n_heads
+         if expand_ratio is None:
+             expand_ratio = d_model // n_heads
+         self.expand_ratio = expand_ratio
+         self.head_f_dim = expand_ratio
+         self.head_i_dim = d_model // n_heads
+         self.forget_dim = n_heads * expand_ratio
+         self.use_rope = use_rope
+
+         self.q_proj = nn.Linear(d_model, self.forget_dim, bias=False)
+         self.f_proj = nn.Linear(d_model, self.forget_dim, bias=False)
+         self.i_proj = nn.Linear(d_model, d_model, bias=False)
+         self.g_norm = ZCRMSNorm(d_model)
+         self.o_proj = nn.Linear(d_model, d_model, bias=False)
+
+     def forward(
+         self,
+         x: torch.Tensor,
+         cos: Optional[torch.Tensor] = None,
+         sin: Optional[torch.Tensor] = None,
+         recurrent_state: Optional[torch.Tensor] = None,
+         output_state: bool = False,
+         use_recurrent_kernel: bool = False,
+     ) -> Tuple[torch.Tensor, Optional[torch.Tensor]]:
+         """
+         Args:
+             x: [B, T, D] hidden states.
+             cos, sin: RoPE caches if `use_rope`, else ignored.
+             recurrent_state: optional [B, H, K, V] HGRN2 state to seed the scan.
+             output_state: if True, return the final HGRN2 state alongside output.
+             use_recurrent_kernel: if True (single-token decode), call the
+                 fused recurrent kernel; otherwise call chunk_gla.
+         """
+         B, T, D = x.shape
+
+         q = F.silu(self.q_proj(x))
+         forget_logits = self.f_proj(x)
+         g = F.logsigmoid(forget_logits)
+         k = torch.sigmoid(-forget_logits)
+         v = self.i_proj(x)
+
+         q = q.view(B, T, self.n_heads, self.head_f_dim)
+         k = k.view(B, T, self.n_heads, self.head_f_dim)
+         g = g.view(B, T, self.n_heads, self.head_f_dim)
+         v = v.view(B, T, self.n_heads, self.head_i_dim)
+
+         if self.use_rope and cos is not None:
+             q = _apply_rope(q, cos, sin)
+             k = _apply_rope(k, cos, sin)
+
+         if use_recurrent_kernel:
+             o, final_state = _fused_recurrent_gla(
+                 q=q, k=k, v=v, gk=g,
+                 initial_state=recurrent_state,
+                 output_final_state=True,
+             )
+         else:
+             o, final_state = _chunk_gla(
+                 q=q, k=k, v=v, g=g,
+                 initial_state=recurrent_state,
+                 output_final_state=output_state,
+             )
+
+         o = o.reshape(B, T, D)
+         o = self.g_norm(o)
+         o = self.o_proj(o)
+
+         if output_state:
+             return o, final_state
+         return o, None
+
+
+ class OdinNextBlock(nn.Module):
+     """Pre-norm block with gated residuals.
+
+     Gates were absorbed and frozen at training time: `gate_attn` and
+     `gate_ffn` are stored as scalars whose `sigmoid()` ≈ 1 by the time of
+     this checkpoint. They remain in the state_dict for compatibility.
+     """
+
+     def __init__(
+         self,
+         d_model: int,
+         n_heads: int,
+         ffn_inner: int,
+         use_rope: bool = True,
+     ):
+         super().__init__()
+         self.pre_norm = ZCRMSNorm(d_model)
+         self.attn = OdinNextAttention(
+             d_model=d_model, n_heads=n_heads, use_rope=use_rope
+         )
+         self.ffn_norm = ZCRMSNorm(d_model)
+         self.ffn = SwiGLU2(d_model, ffn_inner)
+         self.gate_attn = nn.Parameter(torch.zeros(1))
+         self.gate_ffn = nn.Parameter(torch.zeros(1))
+
+     def forward(
+         self,
+         x: torch.Tensor,
+         cos: Optional[torch.Tensor] = None,
+         sin: Optional[torch.Tensor] = None,
+         recurrent_state: Optional[torch.Tensor] = None,
+         output_state: bool = False,
+         use_recurrent_kernel: bool = False,
+     ) -> Tuple[torch.Tensor, Optional[torch.Tensor]]:
+         attn_out, new_state = self.attn(
+             self.pre_norm(x),
+             cos=cos, sin=sin,
+             recurrent_state=recurrent_state,
+             output_state=output_state,
+             use_recurrent_kernel=use_recurrent_kernel,
+         )
+         x = x + torch.sigmoid(self.gate_attn) * attn_out
+         x = x + torch.sigmoid(self.gate_ffn) * self.ffn(self.ffn_norm(x))
+         return x, new_state
+
+
+ # ---------------------------------------------------------------------------
+ # OdinNext recurrent-state cache
+ # ---------------------------------------------------------------------------
+
+
+ class OdinNextCache:
+     """Container for HGRN2 recurrent states across all layers.
+
+     Wraps `List[Optional[Tensor]]` (one per layer, each [B, H, K, V]) with
+     just enough surface to satisfy HuggingFace `generate()`'s expectations
+     for `past_key_values`. Importantly, cache size is independent of T:
+     it is the per-layer hidden-state matrix S, not a growing K/V tape.
+
+     Also tracks `seen_tokens`, the number of input positions the cache has
+     consumed so far, which OdinNext uses to look up the correct RoPE
+     position offset during decode.
+     """
+
+     def __init__(self, n_layers: int):
+         self.n_layers = n_layers
+         self.states: List[Optional[torch.Tensor]] = [None] * n_layers
+         self.seen_tokens: int = 0
+
+     def __len__(self) -> int:
+         return self.n_layers
+
+     def __getitem__(self, idx: int) -> Optional[torch.Tensor]:
+         return self.states[idx]
+
+     def __setitem__(self, idx: int, value: Optional[torch.Tensor]) -> None:
+         self.states[idx] = value
+
+     def __iter__(self):
+         return iter(self.states)
+
+     def get_seq_length(self, layer_idx: Optional[int] = 0) -> int:
+         return self.seen_tokens
+
+     def get_max_length(self) -> Optional[int]:
+         return None  # HGRN2 has no hard cache length cap
+
+     def update_seen(self, n_new_tokens: int) -> None:
+         self.seen_tokens += n_new_tokens
+
+     def to(self, device: torch.device) -> "OdinNextCache":
+         for i, s in enumerate(self.states):
+             if s is not None:
+                 self.states[i] = s.to(device)
+         return self
+
+
+ # ---------------------------------------------------------------------------
+ # OdinNext PreTrainedModel: HF integration
+ # ---------------------------------------------------------------------------
+
+
+ class OdinNextPreTrainedModel(PreTrainedModel):
+     """Base class wiring up HF infrastructure for OdinNext."""
+
+     config_class = OdinNextConfig
+     base_model_prefix = "model"
+     supports_gradient_checkpointing = False
+     _no_split_modules = ["OdinNextBlock"]
+     _skip_keys_device_placement = "past_key_values"
+     _supports_cache_class = False  # we use our own OdinNextCache
+
+     def _init_weights(self, module: nn.Module) -> None:
+         """Conservative init: at inference we only need to define defaults
+         in case someone constructs an OdinNext from scratch.
+         """
+         std = getattr(self.config, "initializer_range", 0.02)
+         if isinstance(module, nn.Linear):
+             nn.init.xavier_uniform_(module.weight)
+             if module.bias is not None:
+                 nn.init.zeros_(module.bias)
+         elif isinstance(module, nn.Embedding):
+             module.weight.data.normal_(mean=0.0, std=std)
+
+
+ class OdinNextModel(OdinNextPreTrainedModel):
+     """Backbone (no LM head)."""
+
+     def __init__(self, config: OdinNextConfig):
+         super().__init__(config)
+         self.config = config
+
+         self.tok_embeddings = nn.Embedding(config.vocab_size, config.d_model)
+         self.layers = nn.ModuleList([
+             OdinNextBlock(
+                 d_model=config.d_model,
+                 n_heads=config.n_heads,
+                 ffn_inner=config.ffn_inner,
+                 use_rope=(i % 2 == 0),
+             )
+             for i in range(config.n_layers)
+         ])
+         self.final_norm = ZCRMSNorm(config.d_model)
+
+         # RoPE caches are lazy-built on first forward. Storing them as
+         # `register_buffer(..., persistent=False)` is incompatible with
+         # `from_pretrained(low_cpu_mem_usage=True)`: HF builds the model on
+         # the meta device and only materializes tensors that appear in the
+         # checkpoint. Non-persistent buffers are NOT in the checkpoint and
+         # so end up backed by uninitialized memory after meta -> real
+         # transfer. We side-step this entirely by computing cos/sin on the
+         # first forward, cached on the model object as plain attributes.
+         self._cos_cache: Optional[torch.Tensor] = None
+         self._sin_cache: Optional[torch.Tensor] = None
+
+         # Skip _init_weights here: we expect to load weights from a
+         # pretrained checkpoint immediately after construction.
+
+     def get_input_embeddings(self) -> nn.Embedding:
+         return self.tok_embeddings
+
+     def set_input_embeddings(self, value: nn.Embedding) -> None:
+         self.tok_embeddings = value
+
+     # -----------------------------------------------------------------
+     # Forward
+     # -----------------------------------------------------------------
+
+     def _ensure_rope_cache(self, target_device: torch.device) -> None:
+         """Build the RoPE cos/sin caches on `target_device` if not already built.
+
+         Cached as plain Python attributes (not buffers) to avoid HF's
+         `low_cpu_mem_usage=True` meta-device materialization issue with
+         non-persistent buffers.
+         """
+         need_build = (
+             self._cos_cache is None
+             or self._cos_cache.device != target_device
+         )
+         if not need_build:
+             return
+         head_f_dim = self.config.d_model // self.config.n_heads
+         half_dim = head_f_dim // 2
+         freqs = 1.0 / (
+             self.config.rope_theta
+             ** (
+                 torch.arange(0, half_dim, dtype=torch.float32, device=target_device)
+                 / half_dim
+             )
+         )
+         t = torch.arange(self.config.max_seq_len, dtype=torch.float32, device=target_device)
+         angles = torch.outer(t, freqs)
+         self._cos_cache = angles.cos()
+         self._sin_cache = angles.sin()
+
+     def _rope_slice(
+         self,
+         seq_len: int,
+         offset: int,
+         target_dtype: torch.dtype,
+         target_device: torch.device,
+     ) -> Tuple[torch.Tensor, torch.Tensor]:
+         end = offset + seq_len
+         if end > self.config.max_seq_len:
+             raise ValueError(
+                 f"Position {end} exceeds max_seq_len={self.config.max_seq_len}. "
+                 "OdinNext was trained with a 2048-token RoPE cache."
+             )
+         self._ensure_rope_cache(target_device)
+         cos = self._cos_cache[offset:end].to(dtype=target_dtype)
+         sin = self._sin_cache[offset:end].to(dtype=target_dtype)
+         cos = cos.unsqueeze(0).unsqueeze(2)  # [1, T, 1, D/2]
+         sin = sin.unsqueeze(0).unsqueeze(2)
+         return cos, sin
+
+     def forward(
+         self,
+         input_ids: torch.Tensor,
+         attention_mask: Optional[torch.Tensor] = None,
+         past_key_values: Optional[OdinNextCache] = None,
+         use_cache: Optional[bool] = None,
+         output_hidden_states: Optional[bool] = None,
+         return_dict: Optional[bool] = None,
+         **_unused,
+     ) -> Tuple[torch.Tensor, Optional[OdinNextCache]]:
+         """Backbone forward.
+
+         Returns `(hidden_states, past_key_values)`. The LM-head wrapper
+         (`OdinNextForCausalLM`) projects to logits.
+
+         Note: `attention_mask` is accepted for HF API compatibility but is
+         NOT used. HGRN2 is causal by construction (the recurrence is strictly
+         forward-in-time) and cannot honor a left-padded mask. For correct
+         results with batched generation, callers must right-pad and ensure
+         all sequences in a batch have valid tokens at every position they
+         process. Single-sequence generation is unaffected.
+         """
+         if use_cache is None:
+             use_cache = self.config.use_cache
+
+         B, T = input_ids.shape
+
+         # Coerce past_key_values to our expected type before reading any
+         # state from it. HF generate may try to auto-instantiate a
+         # DynamicCache or pass a legacy tuple; we want strict OdinNextCache
+         # or None, otherwise `seen_tokens` below may not exist.
+         if past_key_values is not None and not isinstance(past_key_values, OdinNextCache):
+             past_key_values = None
+
+         # Determine if we're in single-token decode mode.
+         single_step = (T == 1) and (past_key_values is not None)
+
+         # RoPE position offset
+         if past_key_values is not None:
+             offset = past_key_values.seen_tokens
+         else:
+             offset = 0
+
+         h = self.tok_embeddings(input_ids)
+
+         # Prepare RoPE caches in the embedding's dtype.
+         cos, sin = self._rope_slice(
+             seq_len=T, offset=offset,
+             target_dtype=h.dtype, target_device=h.device,
+         )
+
+         # Allocate a fresh cache only after the offset and decode-mode
+         # decisions above, so a new cache never counts as a decode step.
+         if past_key_values is None and use_cache:
+             past_key_values = OdinNextCache(self.config.n_layers)
+
+         for i, layer in enumerate(self.layers):
+             prev_state = past_key_values[i] if past_key_values is not None else None
+             h, new_state = layer(
+                 h,
+                 cos=cos, sin=sin,
+                 recurrent_state=prev_state,
+                 output_state=use_cache,
+                 use_recurrent_kernel=single_step,
+             )
+             if use_cache and past_key_values is not None:
+                 past_key_values[i] = new_state
+
+         h = self.final_norm(h)
+
+         if past_key_values is not None:
+             past_key_values.update_seen(T)
+
+         return h, past_key_values
+
+
+ class OdinNextForCausalLM(OdinNextPreTrainedModel):
+     """Top-level wrapper exposing logits + HF generate()."""
+
+     # Map tied output -> source. Newer `transformers` (>=4.45) expects a
+     # dict; older versions tolerate (and used) a list of keys. Provide the
+     # dict form which is forward-compatible.
+     _tied_weights_keys = {"lm_head.weight": "model.tok_embeddings.weight"}
+
+     def __init__(self, config: OdinNextConfig):
+         super().__init__(config)
+         self.model = OdinNextModel(config)
+         self.lm_head = nn.Linear(config.d_model, config.vocab_size, bias=False)
+
+         if config.tie_embeddings:
+             self.lm_head.weight = self.model.tok_embeddings.weight
+
+         self.post_init()
+
+     def get_input_embeddings(self) -> nn.Embedding:
+         return self.model.tok_embeddings
+
+     def set_input_embeddings(self, value: nn.Embedding) -> None:
+         self.model.tok_embeddings = value
+
+     def get_output_embeddings(self) -> nn.Linear:
+         return self.lm_head
+
+     def set_output_embeddings(self, new_embeddings: nn.Linear) -> None:
+         self.lm_head = new_embeddings
+
+     def forward(
+         self,
+         input_ids: torch.Tensor,
+         attention_mask: Optional[torch.Tensor] = None,
+         past_key_values: Optional[OdinNextCache] = None,
+         labels: Optional[torch.Tensor] = None,
+         use_cache: Optional[bool] = None,
+         output_hidden_states: Optional[bool] = None,
+         return_dict: Optional[bool] = None,
+         **_unused,
+     ) -> Union[Tuple, CausalLMOutputWithPast]:
+         return_dict = return_dict if return_dict is not None else True
+
+         hidden_states, past_key_values = self.model(
+             input_ids=input_ids,
+             attention_mask=attention_mask,
+             past_key_values=past_key_values,
+             use_cache=use_cache,
+         )
+         logits = self.lm_head(hidden_states)
+
+         loss = None
+         if labels is not None:
+             shift_logits = logits[..., :-1, :].contiguous()
+             shift_labels = labels[..., 1:].contiguous()
+             loss = F.cross_entropy(
+                 shift_logits.view(-1, shift_logits.size(-1)).float(),
+                 shift_labels.view(-1).long(),
+                 ignore_index=-100,
+             )
+
+         if not return_dict:
+             output = (logits,) + ((past_key_values,) if past_key_values is not None else ())
+             return ((loss,) + output) if loss is not None else output
+
+         return CausalLMOutputWithPast(
+             loss=loss,
+             logits=logits,
+             past_key_values=past_key_values,
+             hidden_states=None,
+             attentions=None,
+         )
+
+     # -----------------------------------------------------------------
+     # generate() integration
+     # -----------------------------------------------------------------
+
+     def prepare_inputs_for_generation(
+         self,
+         input_ids: torch.Tensor,
+         past_key_values: Optional[OdinNextCache] = None,
+         attention_mask: Optional[torch.Tensor] = None,
+         use_cache: Optional[bool] = True,
+         **kwargs,
+     ) -> dict:
+         """Trim input_ids to only the new positions when a cache exists.
+
+         After the first forward, the recurrent state already encodes the
+         prompt. Subsequent calls only need to pass the most recently
+         generated token.
+         """
+         # Only an OdinNextCache tracks `seen_tokens`. Any other cache type is
+         # discarded by the backbone, so the full input is passed through.
+         if isinstance(past_key_values, OdinNextCache) and past_key_values.seen_tokens > 0:
+             # New tokens since last call.
+             new_count = input_ids.shape[1] - past_key_values.seen_tokens
+             if new_count <= 0:
+                 # generate() can occasionally call us with the same length
+                 # twice (e.g., assisted-decoding paths). Default to feeding
+                 # the last token only.
+                 input_ids = input_ids[:, -1:]
+             else:
+                 input_ids = input_ids[:, -new_count:]
+
+         return {
+             "input_ids": input_ids,
+             "past_key_values": past_key_values,
+             "attention_mask": attention_mask,
+             "use_cache": use_cache,
+         }
+
+     def _reorder_cache(
+         self, past_key_values: OdinNextCache, beam_idx: torch.Tensor
+     ) -> OdinNextCache:
+         """Beam-search support: reorder per-layer states along the batch axis."""
+         for i, state in enumerate(past_key_values.states):
+             if state is not None:
+                 past_key_values.states[i] = state.index_select(0, beam_idx.to(state.device))
+         return past_key_values
+
+     @staticmethod
+     def _supports_default_dynamic_cache() -> bool:
+         return False
+
+
+ # Re-export for convenience
+ __all__ = [
+     "OdinNextConfig",
+     "OdinNextModel",
+     "OdinNextForCausalLM",
+     "OdinNextCache",
+ ]
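The cache-aware trimming described in `prepare_inputs_for_generation` can be illustrated without `torch`. The following is a minimal sketch, assuming token ids are held in plain nested lists. The helper name `trim_to_new_tokens` is hypothetical and not part of this repository; it only mirrors the branch logic of the method above.

```python
from typing import List


def trim_to_new_tokens(input_ids: List[List[int]], seen_tokens: int) -> List[List[int]]:
    """Return only the token columns the recurrent state has not consumed yet.

    `input_ids` is a batch of equal-length rows of token ids; `seen_tokens`
    is the number of positions already folded into the cache.
    (Hypothetical helper, for illustration only.)
    """
    if seen_tokens <= 0:
        # No cache yet: the whole prompt must be processed.
        return input_ids
    new_count = len(input_ids[0]) - seen_tokens
    if new_count <= 0:
        # Same length seen twice: fall back to feeding the last token only.
        new_count = 1
    return [row[-new_count:] for row in input_ids]


# First call: nothing cached, so the full prompt goes through.
first_call = trim_to_new_tokens([[11, 12, 13, 14]], seen_tokens=0)
# Later call: four positions cached, one token appended since then.
next_call = trim_to_new_tokens([[11, 12, 13, 14, 15]], seen_tokens=4)
```

Because the recurrent state has a fixed size, feeding only the unseen suffix keeps each decode step independent of the prompt length.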
special_tokens_map.json ADDED
@@ -0,0 +1,5 @@
+ {
+   "bos_token": "<|endoftext|>",
+   "eos_token": "<|endoftext|>",
+   "pad_token": "<|pad|>"
+ }
tokenizer.json ADDED
The diff for this file is too large to render. See raw diff
 
tokenizer_config.json ADDED
@@ -0,0 +1,8 @@
+ {
+   "tokenizer_class": "PreTrainedTokenizerFast",
+   "model_max_length": 2048,
+   "bos_token": "<|endoftext|>",
+   "eos_token": "<|endoftext|>",
+   "pad_token": "<|pad|>",
+   "clean_up_tokenization_spaces": false
+ }
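Since the backbone ignores `attention_mask` and its docstring asks callers to right-pad, a batch could be prepared roughly as follows. This is a sketch under stated assumptions: `right_pad` is a hypothetical helper, and `pad_id=0` is a placeholder, because the real id of `<|pad|>` lives in `tokenizer.json` and is not shown in this diff. Only the 2,048 limit comes from `model_max_length` above.

```python
from typing import List, Tuple


def right_pad(
    batch: List[List[int]], pad_id: int, max_len: int = 2048
) -> Tuple[List[List[int]], List[int]]:
    """Right-pad token-id rows to a common length.

    Returns the padded rows plus each row's true length, so a caller can
    read outputs at the last real position instead of at a pad position.
    (Hypothetical helper, for illustration only.)
    """
    lengths = [len(row) for row in batch]
    longest = max(lengths)
    if longest > max_len:
        raise ValueError(f"Sequence of length {longest} exceeds max_len={max_len}.")
    padded = [row + [pad_id] * (longest - len(row)) for row in batch]
    return padded, lengths


# pad_id=0 is a placeholder; look up the real id of "<|pad|>" in the tokenizer.
padded, lengths = right_pad([[5, 6, 7], [8]], pad_id=0)
```

Keeping the true lengths matters here because the recurrence still processes the pad positions; outputs past a row's real length should be ignored.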