smcleish committed on
Commit
bfbf818
·
verified ·
1 Parent(s): 238c4cb

Upload RavenForCausalLM

Browse files
config.json ADDED
@@ -0,0 +1,58 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "activation_checkpoint_impl": "per-iteration",
3
+ "architecture_class_name": "RecurrentGPT",
4
+ "architectures": [
5
+ "RavenForCausalLM"
6
+ ],
7
+ "auto_map": {
8
+ "AutoConfig": "raven_config_minimal.RavenConfig",
9
+ "AutoModelForCausalLM": "raven_modeling_minimal.RavenForCausalLM"
10
+ },
11
+ "bias": false,
12
+ "block_class_name": "SandwichBlock",
13
+ "block_size": 1024,
14
+ "bos_token_id": 65504,
15
+ "effective_expected_depth": 56,
16
+ "eos_token_id": 65505,
17
+ "head_dim": 128,
18
+ "init_orthogonal": false,
19
+ "init_strategy": "takase",
20
+ "init_values": {
21
+ "embed_scale": 1.0,
22
+ "embedding": 0.008703882797784892,
23
+ "out_proj": 0.0005356869554443541,
24
+ "std": 0.008703882797784892
25
+ },
26
+ "injection_type": "linear",
27
+ "intermediate_size": 8192,
28
+ "max_position_embeddings": 4096,
29
+ "mean_backprop_depth": 8,
30
+ "mean_recurrence": 8,
31
+ "mlp_class_name": "GatedMLP",
32
+ "model_type": "huginn_raven",
33
+ "n_embd": 2048,
34
+ "n_heads": 16,
35
+ "n_layers": 14,
36
+ "n_layers_in_coda": 4,
37
+ "n_layers_in_prelude": 4,
38
+ "n_layers_in_recurrent_block": 6,
39
+ "nonlin_name": "SiLU",
40
+ "norm_class_name": "RMSNorm_llama",
41
+ "norm_eps": 1e-06,
42
+ "num_key_value_heads": 16,
43
+ "pad_token_id": 65509,
44
+ "padded_vocab_size": 100352,
45
+ "padding_multiple": 4096,
46
+ "qk_bias": false,
47
+ "rope_base": 500000.0,
48
+ "rope_theta": 500000,
49
+ "sampling_scheme": "poisson-lognormal-filling",
50
+ "state_init": "like-init",
51
+ "test_time_noise": 0,
52
+ "test_time_noise_type": "fixed",
53
+ "tie_embeddings": false,
54
+ "tie_word_embeddings": false,
55
+ "torch_dtype": "float32",
56
+ "transformers_version": "4.53.1",
57
+ "vocab_size": 100352
58
+ }
generation_config.json ADDED
@@ -0,0 +1,7 @@
 
 
 
 
 
 
 
 
1
+ {
2
+ "_from_model_config": true,
3
+ "bos_token_id": 65504,
4
+ "eos_token_id": 65505,
5
+ "pad_token_id": 65509,
6
+ "transformers_version": "4.53.1"
7
+ }
model-00001-of-00002.safetensors ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:37ade927c006425f11219a5f2d9728943c1396082bf971cb81ff6d57a4e04067
3
+ size 4614214304
model-00002-of-00002.safetensors ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:37fe3f69ac6c8319edcc7eaa7712f1117724369da1f103e90d1dbbf76e4b6e83
3
+ size 822083712
model.safetensors.index.json ADDED
@@ -0,0 +1,124 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "metadata": {
3
+ "total_parameters": 1359071232,
4
+ "total_size": 5436284928
5
+ },
6
+ "weight_map": {
7
+ "lm_head.weight": "model-00002-of-00002.safetensors",
8
+ "transformer.adapter.weight": "model-00001-of-00002.safetensors",
9
+ "transformer.coda.0.attn.Wqkv.weight": "model-00001-of-00002.safetensors",
10
+ "transformer.coda.0.attn.k_norm.weight": "model-00001-of-00002.safetensors",
11
+ "transformer.coda.0.attn.proj.weight": "model-00001-of-00002.safetensors",
12
+ "transformer.coda.0.attn.q_norm.weight": "model-00001-of-00002.safetensors",
13
+ "transformer.coda.0.mlp.fc.weight": "model-00001-of-00002.safetensors",
14
+ "transformer.coda.0.mlp.proj.weight": "model-00001-of-00002.safetensors",
15
+ "transformer.coda.0.norm_1.weight": "model-00001-of-00002.safetensors",
16
+ "transformer.coda.0.norm_2.weight": "model-00001-of-00002.safetensors",
17
+ "transformer.coda.1.attn.Wqkv.weight": "model-00001-of-00002.safetensors",
18
+ "transformer.coda.1.attn.k_norm.weight": "model-00001-of-00002.safetensors",
19
+ "transformer.coda.1.attn.proj.weight": "model-00001-of-00002.safetensors",
20
+ "transformer.coda.1.attn.q_norm.weight": "model-00001-of-00002.safetensors",
21
+ "transformer.coda.1.mlp.fc.weight": "model-00001-of-00002.safetensors",
22
+ "transformer.coda.1.mlp.proj.weight": "model-00001-of-00002.safetensors",
23
+ "transformer.coda.1.norm_1.weight": "model-00001-of-00002.safetensors",
24
+ "transformer.coda.1.norm_2.weight": "model-00001-of-00002.safetensors",
25
+ "transformer.coda.2.attn.Wqkv.weight": "model-00001-of-00002.safetensors",
26
+ "transformer.coda.2.attn.k_norm.weight": "model-00001-of-00002.safetensors",
27
+ "transformer.coda.2.attn.proj.weight": "model-00001-of-00002.safetensors",
28
+ "transformer.coda.2.attn.q_norm.weight": "model-00001-of-00002.safetensors",
29
+ "transformer.coda.2.mlp.fc.weight": "model-00001-of-00002.safetensors",
30
+ "transformer.coda.2.mlp.proj.weight": "model-00001-of-00002.safetensors",
31
+ "transformer.coda.2.norm_1.weight": "model-00001-of-00002.safetensors",
32
+ "transformer.coda.2.norm_2.weight": "model-00001-of-00002.safetensors",
33
+ "transformer.coda.3.attn.Wqkv.weight": "model-00001-of-00002.safetensors",
34
+ "transformer.coda.3.attn.k_norm.weight": "model-00001-of-00002.safetensors",
35
+ "transformer.coda.3.attn.proj.weight": "model-00001-of-00002.safetensors",
36
+ "transformer.coda.3.attn.q_norm.weight": "model-00001-of-00002.safetensors",
37
+ "transformer.coda.3.mlp.fc.weight": "model-00001-of-00002.safetensors",
38
+ "transformer.coda.3.mlp.proj.weight": "model-00001-of-00002.safetensors",
39
+ "transformer.coda.3.norm_1.weight": "model-00001-of-00002.safetensors",
40
+ "transformer.coda.3.norm_2.weight": "model-00001-of-00002.safetensors",
41
+ "transformer.core_block.0.attn.Wqkv.weight": "model-00001-of-00002.safetensors",
42
+ "transformer.core_block.0.attn.k_norm.weight": "model-00001-of-00002.safetensors",
43
+ "transformer.core_block.0.attn.proj.weight": "model-00001-of-00002.safetensors",
44
+ "transformer.core_block.0.attn.q_norm.weight": "model-00001-of-00002.safetensors",
45
+ "transformer.core_block.0.mlp.fc.weight": "model-00001-of-00002.safetensors",
46
+ "transformer.core_block.0.mlp.proj.weight": "model-00001-of-00002.safetensors",
47
+ "transformer.core_block.0.norm_1.weight": "model-00001-of-00002.safetensors",
48
+ "transformer.core_block.0.norm_2.weight": "model-00001-of-00002.safetensors",
49
+ "transformer.core_block.1.attn.Wqkv.weight": "model-00001-of-00002.safetensors",
50
+ "transformer.core_block.1.attn.k_norm.weight": "model-00001-of-00002.safetensors",
51
+ "transformer.core_block.1.attn.proj.weight": "model-00001-of-00002.safetensors",
52
+ "transformer.core_block.1.attn.q_norm.weight": "model-00001-of-00002.safetensors",
53
+ "transformer.core_block.1.mlp.fc.weight": "model-00001-of-00002.safetensors",
54
+ "transformer.core_block.1.mlp.proj.weight": "model-00001-of-00002.safetensors",
55
+ "transformer.core_block.1.norm_1.weight": "model-00001-of-00002.safetensors",
56
+ "transformer.core_block.1.norm_2.weight": "model-00001-of-00002.safetensors",
57
+ "transformer.core_block.2.attn.Wqkv.weight": "model-00001-of-00002.safetensors",
58
+ "transformer.core_block.2.attn.k_norm.weight": "model-00001-of-00002.safetensors",
59
+ "transformer.core_block.2.attn.proj.weight": "model-00001-of-00002.safetensors",
60
+ "transformer.core_block.2.attn.q_norm.weight": "model-00001-of-00002.safetensors",
61
+ "transformer.core_block.2.mlp.fc.weight": "model-00001-of-00002.safetensors",
62
+ "transformer.core_block.2.mlp.proj.weight": "model-00001-of-00002.safetensors",
63
+ "transformer.core_block.2.norm_1.weight": "model-00001-of-00002.safetensors",
64
+ "transformer.core_block.2.norm_2.weight": "model-00001-of-00002.safetensors",
65
+ "transformer.core_block.3.attn.Wqkv.weight": "model-00001-of-00002.safetensors",
66
+ "transformer.core_block.3.attn.k_norm.weight": "model-00001-of-00002.safetensors",
67
+ "transformer.core_block.3.attn.proj.weight": "model-00001-of-00002.safetensors",
68
+ "transformer.core_block.3.attn.q_norm.weight": "model-00001-of-00002.safetensors",
69
+ "transformer.core_block.3.mlp.fc.weight": "model-00001-of-00002.safetensors",
70
+ "transformer.core_block.3.mlp.proj.weight": "model-00001-of-00002.safetensors",
71
+ "transformer.core_block.3.norm_1.weight": "model-00001-of-00002.safetensors",
72
+ "transformer.core_block.3.norm_2.weight": "model-00001-of-00002.safetensors",
73
+ "transformer.core_block.4.attn.Wqkv.weight": "model-00001-of-00002.safetensors",
74
+ "transformer.core_block.4.attn.k_norm.weight": "model-00001-of-00002.safetensors",
75
+ "transformer.core_block.4.attn.proj.weight": "model-00001-of-00002.safetensors",
76
+ "transformer.core_block.4.attn.q_norm.weight": "model-00001-of-00002.safetensors",
77
+ "transformer.core_block.4.mlp.fc.weight": "model-00001-of-00002.safetensors",
78
+ "transformer.core_block.4.mlp.proj.weight": "model-00001-of-00002.safetensors",
79
+ "transformer.core_block.4.norm_1.weight": "model-00001-of-00002.safetensors",
80
+ "transformer.core_block.4.norm_2.weight": "model-00001-of-00002.safetensors",
81
+ "transformer.core_block.5.attn.Wqkv.weight": "model-00001-of-00002.safetensors",
82
+ "transformer.core_block.5.attn.k_norm.weight": "model-00001-of-00002.safetensors",
83
+ "transformer.core_block.5.attn.proj.weight": "model-00001-of-00002.safetensors",
84
+ "transformer.core_block.5.attn.q_norm.weight": "model-00001-of-00002.safetensors",
85
+ "transformer.core_block.5.mlp.fc.weight": "model-00001-of-00002.safetensors",
86
+ "transformer.core_block.5.mlp.proj.weight": "model-00001-of-00002.safetensors",
87
+ "transformer.core_block.5.norm_1.weight": "model-00001-of-00002.safetensors",
88
+ "transformer.core_block.5.norm_2.weight": "model-00001-of-00002.safetensors",
89
+ "transformer.ln_f.weight": "model-00001-of-00002.safetensors",
90
+ "transformer.prelude.0.attn.Wqkv.weight": "model-00001-of-00002.safetensors",
91
+ "transformer.prelude.0.attn.k_norm.weight": "model-00001-of-00002.safetensors",
92
+ "transformer.prelude.0.attn.proj.weight": "model-00001-of-00002.safetensors",
93
+ "transformer.prelude.0.attn.q_norm.weight": "model-00001-of-00002.safetensors",
94
+ "transformer.prelude.0.mlp.fc.weight": "model-00001-of-00002.safetensors",
95
+ "transformer.prelude.0.mlp.proj.weight": "model-00001-of-00002.safetensors",
96
+ "transformer.prelude.0.norm_1.weight": "model-00001-of-00002.safetensors",
97
+ "transformer.prelude.0.norm_2.weight": "model-00001-of-00002.safetensors",
98
+ "transformer.prelude.1.attn.Wqkv.weight": "model-00001-of-00002.safetensors",
99
+ "transformer.prelude.1.attn.k_norm.weight": "model-00001-of-00002.safetensors",
100
+ "transformer.prelude.1.attn.proj.weight": "model-00001-of-00002.safetensors",
101
+ "transformer.prelude.1.attn.q_norm.weight": "model-00001-of-00002.safetensors",
102
+ "transformer.prelude.1.mlp.fc.weight": "model-00001-of-00002.safetensors",
103
+ "transformer.prelude.1.mlp.proj.weight": "model-00001-of-00002.safetensors",
104
+ "transformer.prelude.1.norm_1.weight": "model-00001-of-00002.safetensors",
105
+ "transformer.prelude.1.norm_2.weight": "model-00001-of-00002.safetensors",
106
+ "transformer.prelude.2.attn.Wqkv.weight": "model-00001-of-00002.safetensors",
107
+ "transformer.prelude.2.attn.k_norm.weight": "model-00001-of-00002.safetensors",
108
+ "transformer.prelude.2.attn.proj.weight": "model-00001-of-00002.safetensors",
109
+ "transformer.prelude.2.attn.q_norm.weight": "model-00001-of-00002.safetensors",
110
+ "transformer.prelude.2.mlp.fc.weight": "model-00001-of-00002.safetensors",
111
+ "transformer.prelude.2.mlp.proj.weight": "model-00001-of-00002.safetensors",
112
+ "transformer.prelude.2.norm_1.weight": "model-00001-of-00002.safetensors",
113
+ "transformer.prelude.2.norm_2.weight": "model-00001-of-00002.safetensors",
114
+ "transformer.prelude.3.attn.Wqkv.weight": "model-00001-of-00002.safetensors",
115
+ "transformer.prelude.3.attn.k_norm.weight": "model-00001-of-00002.safetensors",
116
+ "transformer.prelude.3.attn.proj.weight": "model-00001-of-00002.safetensors",
117
+ "transformer.prelude.3.attn.q_norm.weight": "model-00001-of-00002.safetensors",
118
+ "transformer.prelude.3.mlp.fc.weight": "model-00001-of-00002.safetensors",
119
+ "transformer.prelude.3.mlp.proj.weight": "model-00001-of-00002.safetensors",
120
+ "transformer.prelude.3.norm_1.weight": "model-00001-of-00002.safetensors",
121
+ "transformer.prelude.3.norm_2.weight": "model-00001-of-00002.safetensors",
122
+ "transformer.wte.weight": "model-00001-of-00002.safetensors"
123
+ }
124
+ }
raven_config_minimal.py ADDED
@@ -0,0 +1,99 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """A HuggingFace-style model configuration."""
2
+
3
+ from transformers import PretrainedConfig
4
+ from math import sqrt
5
+
6
+
7
class RavenConfig(PretrainedConfig):
    """Configuration for the depth-recurrent ("Huginn" / Raven) transformer.

    The architecture is split into a prelude, a weight-shared recurrent block
    that is unrolled a sampled number of times, and a coda; the derived
    ``effective_expected_depth`` below reflects that layout.
    """

    model_type = "huginn_raven"
    keys_to_ignore_at_inference = [""]
    attribute_map = {"num_attention_heads": "n_heads", "hidden_size": "n_embd", "num_hidden_layers": "n_layers"}

    def __init__(
        self,
        n_embd: int = 5280,
        n_heads: int = 55,
        n_layers: int = 8,  # total of prelude + recurrent + coda
        block_size: int = 4096,
        vocab_size: int = 65536,
        padding_multiple: int = 4096,
        tie_embeddings: bool = True,
        intermediate_size: int = 17920,
        bias: bool = False,
        architecture_class_name: str = "RecurrentGPT",
        block_class_name: str = "SandwichBlock",
        norm_class_name: str = "RMSNorm_llama",
        norm_eps: float = 0.000001,
        mlp_class_name: str = "GatedMLP",
        nonlin_name: str = "SiLU",
        init_strategy: str = "takase",
        init_orthogonal: bool = False,
        state_init: str = "like-init",
        injection_type: str = "linear",
        n_layers_in_recurrent_block: int = 4,
        mean_recurrence: int = 32,
        sampling_scheme: str = "poisson-lognormal-filling",
        mean_backprop_depth: int = 8,
        n_layers_in_prelude: int = 2,
        n_layers_in_coda: int = 2,
        qk_bias: bool = True,
        activation_checkpoint_impl: str = "per-iteration",
        rope_base: float = 50_000,
        torch_dtype: str = "bfloat16",
        transformers_version: str = "4.47.1",
        **kwargs,
    ):
        # --- core model dimensions ---
        self.n_embd = n_embd
        self.n_heads = n_heads
        self.n_layers = n_layers
        self.block_size = block_size
        # vocabulary is stored without extra padding: padded and raw size coincide
        self.vocab_size = self.padded_vocab_size = vocab_size
        self.padding_multiple = padding_multiple
        self.tie_embeddings = tie_embeddings
        self.intermediate_size = intermediate_size
        self.bias = bias
        # --- component selection (class names are resolved by the modeling code) ---
        self.architecture_class_name = architecture_class_name
        self.block_class_name = block_class_name
        self.norm_class_name = norm_class_name
        self.norm_eps = norm_eps
        self.mlp_class_name = mlp_class_name
        self.nonlin_name = nonlin_name
        # --- initialization and recurrence hyper-parameters ---
        self.init_strategy = init_strategy
        self.init_orthogonal = init_orthogonal
        self.state_init = state_init
        self.injection_type = injection_type
        self.n_layers_in_recurrent_block = n_layers_in_recurrent_block
        self.mean_recurrence = mean_recurrence
        self.sampling_scheme = sampling_scheme
        self.mean_backprop_depth = mean_backprop_depth
        self.n_layers_in_prelude = n_layers_in_prelude
        self.n_layers_in_coda = n_layers_in_coda
        self.qk_bias = qk_bias
        self.activation_checkpoint_impl = activation_checkpoint_impl
        self.rope_base = rope_base
        self.torch_dtype = torch_dtype  # Added from JSON
        self.transformers_version = transformers_version  # Added from JSON
        # --- inference-time defaults (deliberately not constructor arguments) ---
        self.test_time_noise = 0
        self.test_time_noise_type = "fixed"
        # --- derived quantities ---
        self.num_key_value_heads = n_heads  # no grouped-query attention: KV heads == heads
        self.num_attention_heads = n_heads
        self.head_dim = n_embd // n_heads
        self.effective_expected_depth = (
            self.n_layers_in_prelude + self.n_layers_in_coda + self.n_layers_in_recurrent_block * self.mean_recurrence
        )
        # "takase"-style init scale; out_proj is additionally shrunk with expected depth
        takase_std = sqrt(2 / (5 * self.n_embd))
        self.init_values = {
            "std": takase_std,
            "out_proj": takase_std / sqrt(2 * self.effective_expected_depth),
            "embedding": takase_std,
            "embed_scale": sqrt(self.n_embd),
        }

        super().__init__(
            # pad_token_id=65509,
            # bos_token_id=65504,
            # eos_token_id=65505,
            tie_word_embeddings=tie_embeddings,
            **kwargs,
        )
raven_modeling_minimal.py ADDED
@@ -0,0 +1,1579 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """Modeling file for HF compatibility and zero-shot experiments."""
2
+
3
+ import torch
4
+ import math
5
+
6
+ from torch import Tensor
7
+ from torch.nn.attention.flex_attention import create_block_mask, BlockMask, flex_attention
8
+ from torch.nn.attention import bias as attn_bias
9
+ from dataclasses import dataclass
10
+ from typing import Union, Optional, Any
11
+
12
+
13
+ from .raven_config_minimal import RavenConfig
14
+ from transformers.cache_utils import Cache, DynamicCache, StaticCache
15
+
16
+ ###################### Huggingface Glue code I ##################################################################
17
+ from transformers import PreTrainedModel, GenerationMixin
18
+ from transformers.utils import ModelOutput
19
+ from transformers.generation.utils import GenerateDecoderOnlyOutput
20
+
21
+ import torch.nn.functional as F
22
+ from transformers import GenerationConfig
23
+ from transformers.models.olmo2.modeling_olmo2 import Olmo2RotaryEmbedding, apply_rotary_pos_emb
24
+
25
+ torch.backends.cuda.enable_math_sdp(False)
26
+
27
+
28
class RavenPreTrainedModel(PreTrainedModel):
    """Hugging Face glue: ties the config class to the model and declares
    capability flags (sdpa/flash-attn support, cache classes, splitting rules)
    used by `from_pretrained` and the generation machinery."""

    config_class = RavenConfig
    base_model_prefix = "model"
    supports_gradient_checkpointing = True
    _no_split_modules = ["SandwichBlock"]
    _skip_keys_device_placement = ["past_key_values"]
    _tied_weights_keys = ["lm_head.weight"]
    _supports_flash_attn_2 = True
    _supports_sdpa = True
    _supports_cache_class = True
    _supports_quantized_cache = False
    _supports_static_cache = True
    _tp_plan = {}

    def _init_weights(self, module):
        # Weights are expected to be loaded from a checkpoint; a from-scratch
        # initialization is intentionally unimplemented. When running on meta
        # tensors (e.g. during `from_pretrained` skeleton construction) we stay
        # silent, otherwise we print a notice.
        probe = torch.rand((1,))
        if probe.is_meta:
            return
        print("Random Initialization not implemented.")
46
+
47
@dataclass
class CausalLMOutputRecurrentLatents(ModelOutput):
    """Output container for the recurrent-depth causal LM.

    Mirrors the standard causal-LM output and extends it with the final
    recurrent latent states and optional per-forward statistics.
    """

    loss: Optional[torch.Tensor] = None  # LM loss (presumably only set when labels are provided — confirm in forward)
    log_ppl: Optional[torch.Tensor] = None  # log-perplexity companion to `loss`
    logits: Optional[torch.Tensor] = None  # next-token prediction scores
    past_key_values: Optional[Cache] = None  # KV cache for incremental decoding
    latent_states: Optional[torch.Tensor] = None  # final latent state of the recurrent block
    hidden_states: Optional[torch.Tensor] = None  # last hidden states before the LM head
    attention_maps: Optional[dict[int, torch.Tensor]] = None  # attention probabilities keyed by step/layer index
    stats: Optional[dict] = None  # misc. bookkeeping (e.g. sampled recurrence depth)
57
+
58
+
59
+ ###################### Minimal implementation from here ############################################################
60
+
61
+
62
class RMSNorm(torch.nn.Module):
    """Root-mean-square layer norm with explicit fp32 math.

    The normalization is always computed in float32 (autocast disabled) and
    cast back to the input dtype before the learned scale is applied, which
    keeps dtype handling sane and fuses slightly better.
    """

    def __init__(self, dim: int, eps: float = 1e-6):
        super().__init__()
        self.eps = eps
        self.weight = torch.nn.Parameter(torch.ones(dim))

    def _norm(self, x):
        # x / sqrt(mean(x^2) + eps), reduced over the last (feature) dimension
        mean_square = x.pow(2).mean(dim=-1, keepdim=True)
        return x * torch.rsqrt(mean_square + self.eps)

    def forward(self, x):
        # autocast needs a real device type; fall back to "cuda" for meta tensors
        autocast_device = "cuda" if x.device.type == "meta" else x.device.type
        with torch.autocast(enabled=False, device_type=autocast_device):
            normed = self._norm(x.float()).type_as(x)
            return normed * self.weight

    def reset_parameters(self) -> None:
        torch.nn.init.ones_(self.weight)
79
+
80
+
81
class HuginnDynamicCache(DynamicCache):
    """Dynamic KV cache for the depth-recurrent model.

    Unlike the stock `DynamicCache`, entries are stored *uncoalesced* as
    nested dicts keyed by recurrent-step index and then by absolute token
    position, because per-token adaptive compute may leave some (step, token)
    slots empty. The `lookup_strategy` decides how missing slots are filled
    when the past state is materialized, and the "compress-*" family also
    remaps step indices on write to bound memory.
    """

    def __init__(self, lookup_strategy: str = "full") -> None:
        super().__init__()
        self._seen_tokens = 0
        self.key_cache: dict[int, dict[int, torch.Tensor]] = {}
        self.value_cache: dict[int, dict[int, torch.Tensor]] = {}
        # structure: cache[index_of_layer_or_recurrent_step][index_in_sequence]
        # the cache is held uncoalesced because certain recurrent steps may be missing for some sequence ids if using
        # per-token adaptive compute. In those cases, the "lookup_strategy" determines how to proceed
        # Also, It is critical that the head indices do not overlap with the recurrent iteration indices
        self.lookup_strategy = lookup_strategy

    def update(
        self,
        key_states: torch.Tensor,
        value_states: torch.Tensor,
        step_idx_tensor: torch.Tensor,
        lookup_strategy: Optional[str] = None,
    ) -> tuple[torch.Tensor, torch.Tensor]:
        """Insert new K/V states for one step and return the materialized past.

        `key_states`/`value_states` are indexed along dim -2 by token position
        (assumed shape (..., seq_len, head_dim) — inferred from the `unbind`
        calls below; confirm against the attention code). Returns the stacked
        (keys, values) for all `_seen_tokens` positions, assembled according to
        the lookup strategy when some positions are missing at this step.
        """
        step_idx: int = int(step_idx_tensor)  # todo: fix dicts with tensor step_idx, currently the memberships fail
        lookup_strategy = self.lookup_strategy if lookup_strategy is None else lookup_strategy
        # NOTE(review): the compression remap below reads self.lookup_strategy,
        # not the per-call override — presumably intentional, but worth confirming.
        if "compress-" in self.lookup_strategy and step_idx > 1:  # hardcode for current model!
            if "compress-s" in self.lookup_strategy:
                # modulo compression: fold step onto a cycle of length `compression_stage`
                compression_stage = int(self.lookup_strategy.split("compress-")[1][1:])
                new_step_idx = (step_idx - 2) % compression_stage + 2
            elif "compress-anchor" in self.lookup_strategy:
                if step_idx - 2 < 4 * 8:  # anchor onto first 8 recurrence steps # noqa: SIM108
                    new_step_idx = step_idx
                else:  # then re-use the next 4 KV states = one recurrence for all future recurrence
                    new_step_idx = 34 + (step_idx - 34) % 4
                # print(step_idx, new_step_idx)
            else:  # compress-r
                # relative compression: divide the step range into stages
                compression_stage = int(self.lookup_strategy.split("compress-")[1][1:])
                new_step_idx = (step_idx - 2) // compression_stage + 2
            step_idx = new_step_idx
        # Init
        if step_idx not in self.key_cache:
            self.key_cache[step_idx] = {}
            self.value_cache[step_idx] = {}
        # Update the number of seen tokens, we assume that step_idx=0 (first prelude) is always hit
        if step_idx == 0:
            self._seen_tokens += key_states.shape[-2]
        # Add entries to cache
        for idx, entry in enumerate(key_states.unbind(dim=-2)):
            if "compress-" not in self.lookup_strategy:
                # without compression, a slot must never be written twice
                assert step_idx < 0 or self._seen_tokens - key_states.shape[-2] + idx not in self.key_cache[step_idx]
            self.key_cache[step_idx][self._seen_tokens - key_states.shape[-2] + idx] = entry
        for idx, entry in enumerate(value_states.unbind(dim=-2)):
            self.value_cache[step_idx][self._seen_tokens - value_states.shape[-2] + idx] = entry

        # Materialize past state based on lookup strategy:
        if len(self.key_cache[step_idx]) == self._seen_tokens or self.lookup_strategy == "full":
            # All entries are present, materialize cache as normal
            return (
                torch.stack(list(self.key_cache[step_idx].values()), dim=-2),
                torch.stack(list(self.value_cache[step_idx].values()), dim=-2),
            )
        else:  # some entries were not previously computed
            if lookup_strategy.startswith("latest-m4"):
                # substitute the most recent step with the same (step % 4) phase
                latest_keys = []
                latest_values = []
                for token_pos in range(self._seen_tokens):
                    # For steps >= 2, use modulo 4, this hard-codes the huginn block structure for now
                    if step_idx >= 2:
                        # Find valid steps for this token position
                        valid_steps = [s for s in range(step_idx + 1) if token_pos in self.key_cache[s]]
                        max_step = max([s for s in valid_steps if s >= 2 and s % 4 == step_idx % 4])
                    else:
                        max_step = step_idx if token_pos in self.key_cache[step_idx] else 0
                    latest_keys.append(self.key_cache[max_step][token_pos])
                    latest_values.append(self.value_cache[max_step][token_pos])
                return torch.stack(latest_keys, dim=-2), torch.stack(latest_values, dim=-2)
            elif lookup_strategy.startswith("available-m4"):
                # prefer the exact step; only fall back to the same phase when missing
                latest_keys = []
                latest_values = []
                for token_pos in range(self._seen_tokens):
                    if token_pos in self.key_cache[step_idx]:
                        step = step_idx
                    else:
                        # Find valid steps for this token position
                        valid_steps = [s for s in range(step_idx + 1) if token_pos in self.key_cache[s]]
                        step = max([s for s in valid_steps if s >= 2 and s % 4 == step_idx % 4])
                    latest_keys.append(self.key_cache[step][token_pos])
                    latest_values.append(self.value_cache[step][token_pos])
                return torch.stack(latest_keys, dim=-2), torch.stack(latest_values, dim=-2)
            elif lookup_strategy.startswith("always-last-m4"):
                # like latest-m4 but searching over all cached steps, not only <= step_idx
                latest_keys = []
                latest_values = []
                for token_pos in range(self._seen_tokens):
                    # For steps >= 2, use modulo 4, this hard-codes the huginn block structure for now
                    if step_idx >= 2:
                        # Find valid steps for this token position
                        valid_steps = [key_step for key_step in self.key_cache if token_pos in self.key_cache[key_step]]
                        max_step = max([s for s in valid_steps if s >= 2 and s % 4 == step_idx % 4])
                    else:
                        max_step = step_idx if token_pos in self.key_cache[step_idx] else 0
                    latest_keys.append(self.key_cache[max_step][token_pos])
                    latest_values.append(self.value_cache[max_step][token_pos])
                return torch.stack(latest_keys, dim=-2), torch.stack(latest_values, dim=-2)
            elif lookup_strategy.startswith("skip"):
                # drop missing positions entirely; the returned sequence may be shorter
                existing_keys = []
                existing_values = []
                for token_pos in range(self._seen_tokens):
                    if token_pos in self.key_cache[step_idx]:
                        existing_keys.append(self.key_cache[step_idx][token_pos])
                        existing_values.append(self.value_cache[step_idx][token_pos])
                return torch.stack(existing_keys, dim=-2), torch.stack(existing_values, dim=-2)
            elif lookup_strategy.startswith("randomized"):  # sanity check
                rand_keys = []
                rand_values = []
                for token_pos in range(self._seen_tokens):
                    if step_idx < 2:  # For prelude steps
                        max_step = step_idx if token_pos in self.key_cache[step_idx] else 0
                    else:  # Get all steps from same block position
                        curr_modulo = (step_idx - 2) % 4 + 2
                        valid_steps = [
                            s
                            for s in range(2, step_idx + 1)
                            if (s - 2) % 4 + 2 == curr_modulo and token_pos in self.key_cache[s]
                        ]
                        max_step = valid_steps[torch.randint(len(valid_steps), (1,))]
                    rand_keys.append(self.key_cache[max_step][token_pos])
                    rand_values.append(self.value_cache[max_step][token_pos])
                return torch.stack(rand_keys, dim=-2), torch.stack(rand_values, dim=-2)
            else:
                raise ValueError(f"Unknown lookup strategy: {lookup_strategy}")

    def reset(self) -> None:
        """Reset the cache state."""
        self._seen_tokens = 0
        self.key_cache.clear()
        self.value_cache.clear()

    def clear_last_k_entries(self, k: int = 0):
        """Partially clear cache."""
        # Drops the last k token positions from every step's sub-cache.
        assert self._seen_tokens >= k
        self._seen_tokens = self._seen_tokens - k
        # self.key_cache[step_idx][self._seen_tokens - key_states.shape[-2] + idx] = entry
        self.key_cache = {
            step: {seq: seq_cache for seq, seq_cache in cache.items() if seq < self._seen_tokens}
            for step, cache in self.key_cache.items()
        }
        self.value_cache = {
            step: {seq: seq_cache for seq, seq_cache in cache.items() if seq < self._seen_tokens}
            for step, cache in self.value_cache.items()
        }

    def get_seq_length(self, step_idx: int = 0) -> int:
        """Return the number of tokens seen so far (independent of `step_idx`)."""
        return self._seen_tokens

    def get_memory_usage(self) -> float:
        """Return the approximate cache size in MiB (keys + values)."""
        total_bytes = 0
        # For each recurrent step/layer index
        for step_idx in self.key_cache:
            # Get the sequence cache for this step
            key_seq_cache = self.key_cache[step_idx]
            for seq_idx in key_seq_cache:
                key_tensor = key_seq_cache[seq_idx]
                # Add memory for of key tensors, assuming value is the same
                total_bytes += key_tensor.nelement() * key_tensor.element_size()
        return total_bytes * 2 / (1024 * 1024)
+ return total_bytes * 2 / (1024 * 1024)
242
+
243
+
244
class HuginnStaticCache(Cache):
    """Static Cache for the recurrent model.

    All storage is pre-allocated as dense tensors laid out as
    [step, batch, head, position, head_dim] — one slot per recurrent step — so
    tensor addresses stay fixed across decoding (they are marked static for
    torch.compile below, although `is_compileable` is still False while WIP).
    """

    is_compileable = False  # this is todo

    def __init__(
        self,
        max_length: int,
        max_num_steps: int,
        num_heads: int,
        hidden_dim: int,
        batch_size: int = 1,
        lookup_strategy: str = "full",
        device: Optional[Union[torch.device, str]] = None,
        dtype: torch.dtype = torch.float32,
    ) -> None:
        """Pre-allocate key/value storage.

        Args:
            max_length: maximum sequence length the cache can hold.
            max_num_steps: number of distinct step slots before compression.
            num_heads: number of KV heads.
            hidden_dim: per-head dimension.
            batch_size: fixed batch size of the pre-allocated tensors.
            lookup_strategy: "full", "skip", "latest-m4", "randomized", or a
                "compress-{s|r}<k>" variant that folds steps into fewer slots.
            device: placement of the pre-allocated tensors.
            dtype: dtype of the pre-allocated tensors.
        """
        super().__init__()
        self._seen_tokens = 0  # token positions written at step 0 so far
        self.max_length = max_length
        self.lookup_strategy = lookup_strategy

        # Adjust max_num_steps based on compression strategy
        if "compress-" in lookup_strategy:
            # The digits after "compress-s"/"compress-r" encode the compression stage.
            compression_stage = int(lookup_strategy.split("compress-")[1][1:])
            if "compress-s" in lookup_strategy:
                # For modulo compression (s), we need steps for 0,1 + compressed steps
                self.max_num_steps = 4 + compression_stage
            else:
                # For relative compression, we need steps for 0,1 + compressed steps
                self.max_num_steps = 4 + (max_num_steps - 4 + compression_stage - 1) // compression_stage
        else:
            self.max_num_steps = max_num_steps

        # Pre-allocate cache tensors [steps, batch, heads, seq_len, head_dim]
        device = torch.device(device) if device is not None else None
        cache_shape = (self.max_num_steps, batch_size, num_heads, max_length, hidden_dim)

        self.key_cache = torch.zeros(cache_shape, dtype=dtype, device=device)
        self.value_cache = torch.zeros(cache_shape, dtype=dtype, device=device)
        # valid_mask[step, pos] becomes True once a KV pair is written for that slot.
        self.valid_mask = torch.zeros((self.max_num_steps, max_length), dtype=torch.bool, device=device)
        # Mark tensors as static for compile
        torch._dynamo.mark_static_address(self.key_cache)
        torch._dynamo.mark_static_address(self.value_cache)
        torch._dynamo.mark_static_address(self.valid_mask)

    def update(
        self,
        key_states: torch.Tensor,
        value_states: torch.Tensor,
        step_idx: torch.Tensor,
        lookup_strategy: Optional[str] = None,
    ) -> tuple[torch.Tensor, torch.Tensor]:
        """Write KV states into the slot for `step_idx` and return the KV to attend over.

        Token accounting happens only at step 0; later recurrent steps reuse
        the same sequence positions.
        """
        if step_idx == 0:
            self._seen_tokens += key_states.shape[-2]

        # Adjust step_idx for compression
        lookup_strategy = self.lookup_strategy if lookup_strategy is None else lookup_strategy
        if "compress-" in lookup_strategy and step_idx > 1:
            compression_stage = int(lookup_strategy.split("compress-")[1][1:])
            if "compress-s" in lookup_strategy:
                # Modulo compression: steps >= 2 wrap into `compression_stage` slots.
                step_idx = (step_idx - 2) % compression_stage + 2
            else:
                # Relative compression: consecutive step groups share one slot.
                step_idx = (step_idx - 2) // compression_stage + 2

        # New tokens occupy positions [start_idx, start_idx + chunk).
        start_idx = self._seen_tokens - key_states.shape[-2]

        indices = torch.arange(start_idx, start_idx + key_states.shape[-2], device=key_states.device)
        # index_copy_ along dim 2 = sequence axis of the [batch, heads, seq, dim] slot.
        self.key_cache[step_idx].index_copy_(2, indices, key_states)
        self.value_cache[step_idx].index_copy_(2, indices, value_states)
        self.valid_mask[step_idx, start_idx : start_idx + key_states.shape[-2]] = True

        # Return based on lookup strategy
        if lookup_strategy == "full":
            return (
                self.key_cache[step_idx, :, :, : self._seen_tokens],
                self.value_cache[step_idx, :, :, : self._seen_tokens],
            )
        elif lookup_strategy.startswith("latest-m4"):
            if step_idx >= 2:
                # Candidate slots sharing this step's modulo-4 phase; pick the latest valid one.
                pattern_steps = torch.arange(2, step_idx.item() + 1, 4, device=self.valid_mask.device)
                pattern_valid = self.valid_mask[pattern_steps]
                max_valid_step = pattern_steps[pattern_valid.to(torch.long).argmax(dim=0)]
                # NOTE(review): this indexes the cache as [step, position], but the layout is
                # [step, batch, head, position, dim] — the second index hits the batch axis,
                # unlike the "full" branch above. Confirm the intended indexing.
                return (
                    self.key_cache[max_valid_step, torch.arange(self._seen_tokens)],
                    self.value_cache[max_valid_step, torch.arange(self._seen_tokens)],
                )
            return self.key_cache[step_idx, :, :, : self._seen_tokens], self.value_cache[
                step_idx, :, :, : self._seen_tokens
            ]
        elif lookup_strategy == "skip":
            # Only return positions that were actually written for this step.
            # NOTE(review): the boolean mask has length `seen_tokens` but is applied to the
            # leading (batch) axis of a [batch, heads, seen, dim] slice — verify shapes.
            valid_mask = self.valid_mask[step_idx, : self._seen_tokens]
            return (
                self.key_cache[step_idx, :, :, : self._seen_tokens][valid_mask],
                self.value_cache[step_idx, :, :, : self._seen_tokens][valid_mask],
            )
        elif lookup_strategy.startswith("randomized"):
            if step_idx < 2:
                max_step = step_idx
            else:
                # Pick a random earlier slot sharing this step's modulo-4 phase.
                curr_modulo = (step_idx - 2) % 4 + 2
                valid_steps = (
                    torch.where(
                        (torch.arange(2, step_idx.item() + 1, device=self.valid_mask.device) - 2) % 4 + 2 == curr_modulo
                    )[0]
                    + 2
                )
                rand_idx = torch.randint(len(valid_steps), (1,), device=valid_steps.device)
                max_step = valid_steps[rand_idx]
            # NOTE(review): same [step, position] indexing concern as the latest-m4 branch.
            return self.key_cache[max_step, : self._seen_tokens], self.value_cache[max_step, : self._seen_tokens]
        else:
            raise ValueError(f"Unknown lookup strategy: {lookup_strategy}")

    def reset(self) -> None:
        """Clear all cached state in place; the storage itself stays allocated."""
        self._seen_tokens = 0
        self.key_cache.zero_()
        self.value_cache.zero_()
        self.valid_mask.zero_()

    def get_seq_length(self, step_idx: int = 0) -> int:
        """Number of token positions written so far (identical for every step slot)."""
        return self._seen_tokens

    def get_memory_usage(self) -> float:
        """Total pre-allocated KV storage in MiB (independent of fill level)."""
        return (self.key_cache.nelement() + self.value_cache.nelement()) * self.key_cache.element_size() / (1024 * 1024)
367
+
368
+
369
+ ValidCache = HuginnDynamicCache | HuginnStaticCache
370
+
371
+
372
class CausalSelfAttention(torch.nn.Module):
    """Multi-head causal self-attention with QK-norm, rotary embeddings, and
    grouped-query KV heads (n_kv_heads may be smaller than n_head)."""

    def __init__(self, config: RavenConfig) -> None:
        super().__init__()
        self.config = config
        self.n_head = config.num_attention_heads
        self.n_kv_heads = config.num_key_value_heads
        self.head_dim = getattr(config, "head_dim", config.n_embd // self.n_head)

        # Fused QKV projection; `chunks` records how to split its output back into q, k, v.
        shape = (self.n_head + 2 * self.n_kv_heads) * self.head_dim
        self.chunks = [self.n_head * self.head_dim, self.n_kv_heads * self.head_dim, self.n_kv_heads * self.head_dim]

        self.Wqkv = torch.nn.Linear(config.n_embd, shape, bias=False)
        if config.qk_bias:
            # Learnable additive bias applied to q and k after normalization.
            self.qk_bias = torch.nn.Parameter(torch.zeros(2, 1, self.n_head, self.head_dim))
        self.q_norm = RMSNorm(config.num_attention_heads * config.head_dim, eps=config.norm_eps)  # unlike olmo, only on the head dim!
        self.k_norm = RMSNorm(config.num_key_value_heads * config.head_dim, eps=config.norm_eps)  # thus post q_norm does not need reshape
        self.proj = torch.nn.Linear(self.n_head * self.head_dim, config.n_embd, bias=False)

    def forward(
        self,
        x: Tensor,
        freqs_cis: Tensor,
        block_idx: torch.Tensor,
        mask: Optional[BlockMask] = None,
        past_key_values: Optional[ValidCache] = None,
    ) -> Tensor:
        """Attend over `x`.

        Args:
            x: [batch, seq, n_embd] activations.
            freqs_cis: (cos, sin) rotary tables for the current positions.
            block_idx: layer/step index used as the KV-cache slot key.
            mask: optional flex-attention BlockMask; when None, SDPA with causal
                semantics is used instead.
            past_key_values: optional cache updated in place via `.update()`.
        """
        B, S, E = x.shape  # batch size, sequence length, embedding dimensionality (n_embd)
        q, k, v = self.Wqkv(x).split(self.chunks, dim=2)

        q = self.q_norm(q)
        k = self.k_norm(k)

        q = q.view(B, S, self.n_head, self.head_dim)
        k = k.view(B, S, self.n_kv_heads, self.head_dim)
        v = v.view(B, S, self.n_kv_heads, self.head_dim)
        # bias?
        if self.config.qk_bias:
            q_bias, k_bias = self.qk_bias.split(1, dim=0)
            q, k = (q + q_bias).to(q.dtype), (k + k_bias).to(q.dtype)

        q = q.transpose(1, 2)  # (B, nh, S, hs)
        k = k.transpose(1, 2)
        v = v.transpose(1, 2)

        # apply rotary
        cos, sin = freqs_cis
        q, k = apply_rotary_pos_emb(q, k, cos, sin)

        if past_key_values is not None:
            # The cache returns the full (past + current) keys/values for this block.
            k, v = past_key_values.update(k, v, block_idx)

        if mask is not None:
            y: torch.Tensor = flex_attention(q, k, v, block_mask=mask)  # type: ignore
        else:
            if q.shape[2] < k.shape[2]:
                if q.shape[2] > 1:
                    # Chunked decode: queries are a suffix of the kv sequence, so align
                    # causality to the lower-right corner rather than the default top-left.
                    bias = attn_bias.causal_lower_right(q.shape[2], k.shape[2])
                    y = torch.nn.functional.scaled_dot_product_attention(q, k, v, bias, dropout_p=0.0, enable_gqa=True)
                else:
                    # Single-token decode: every cached position is visible, no mask needed.
                    y = torch.nn.functional.scaled_dot_product_attention(q, k, v, dropout_p=0.0, is_causal=False, enable_gqa=True)
            else:
                y = torch.nn.functional.scaled_dot_product_attention(q, k, v, dropout_p=0.0, is_causal=True, enable_gqa=True)
        y = y.transpose(1, 2).reshape(B, S, self.n_head * self.head_dim).contiguous()  # reshape is a view if possible (it mostly is)
        return self.proj(y)
436
+
437
+
438
class GatedMLP(torch.nn.Module):
    """Gated (SwiGLU-style) feed-forward block.

    The gate and value projections are fused into a single matmul (`fc`) to
    improve parallelism; the output of SiLU(gate) * value is projected back
    down to the embedding width.
    """

    def __init__(self, config: "RavenConfig", in_features: int = 0) -> None:
        super().__init__()
        if in_features == 0:
            in_features = config.n_embd
        # One fused projection producing both the gate branch and the value branch.
        self.fc = torch.nn.Linear(in_features, config.intermediate_size * 2, bias=False)

        self.proj = torch.nn.Linear(config.intermediate_size, config.n_embd, bias=False)
        self.nonlin = torch.nn.SiLU()

    def forward(self, x: "Tensor") -> "Tensor":
        """Apply SiLU(gate) * value, then project back to n_embd."""
        gate, value = self.fc(x).chunk(2, dim=-1)
        return self.proj(self.nonlin(gate) * value)
452
+
453
+
454
class SandwichBlock(torch.nn.Module):
    """Transformer block with post-module ("sandwich") normalization: each
    sub-module's output is RMS-normalized *before* being added back to the
    residual stream."""

    expanded = False  # marker flag read by external code; this variant is the compact block

    def __init__(self, config: RavenConfig, layer_id: int) -> None:
        super().__init__()
        self.norm_1 = RMSNorm(config.n_embd, eps=config.norm_eps)
        self.attn = CausalSelfAttention(config)
        self.norm_2 = RMSNorm(config.n_embd, eps=config.norm_eps)
        self.mlp = GatedMLP(config)
        self.layer_id = layer_id

    def forward(
        self,
        x: Tensor,
        freqs_cis: Tensor,
        step_idx: int,
        mask: Optional[BlockMask] = None,
        past_key_values: Optional[ValidCache] = None,
    ) -> Tensor:
        """One block: x + norm(attn(x)), then x + norm(mlp(x))."""
        attn_out = self.norm_1(self.attn(x, freqs_cis, step_idx, mask, past_key_values))
        x = attn_out + x
        x = self.norm_2(self.mlp(x)) + x
        return x
477
+
478
+
479
+ class RavenForCausalLM(RavenPreTrainedModel, GenerationMixin):
480
+
481
    def __init__(
        self,
        config: RavenConfig,
    ) -> None:
        """Build embedding, prelude blocks, adapter, recurrent core, coda blocks, and LM head."""
        super().__init__(config)
        self.config = config

        # Transformer layers
        prelude = torch.nn.ModuleList(SandwichBlock(config, layer_id=i) for i in range(config.n_layers_in_prelude))
        # Adapter merges [latent_state, input_embedding] (2 * n_embd) back to n_embd each iteration.
        adapter = torch.nn.Linear(config.n_embd * 2, config.n_embd, bias=config.bias)
        core_block = torch.nn.ModuleList(
            SandwichBlock(config, layer_id=i + config.n_layers_in_prelude)
            for i in range(config.n_layers_in_recurrent_block)
        )
        # Coda layer ids start after the prelude plus the *mean* number of unrolled core layers.
        o = config.n_layers_in_prelude + config.n_layers_in_recurrent_block * config.mean_recurrence
        coda = torch.nn.ModuleList(SandwichBlock(config, layer_id=i + o) for i in range(config.n_layers_in_coda))

        self.transformer = torch.nn.ModuleDict(
            dict(
                wte=torch.nn.Embedding(config.padded_vocab_size, config.n_embd),
                prelude=prelude,
                adapter=adapter,
                core_block=core_block,
                coda=coda,
                ln_f=RMSNorm(config.n_embd, eps=config.norm_eps),  # used twice :>
            )
        )
        self.emb_scale = config.init_values["embed_scale"]
        # Head
        self.lm_head = torch.nn.Linear(config.n_embd, config.padded_vocab_size, bias=False)
        if self.config.tie_embeddings:
            self.tie_weights()
        # rope
        self.rotary_emb = Olmo2RotaryEmbedding(config=config)
515
+
516
    def get_input_embeddings(self):
        """Return the token embedding module (HF accessor convention)."""
        return self.transformer.wte
518
+
519
    def get_output_embeddings(self):
        """Return the LM head projection (HF accessor convention)."""
        return self.lm_head
521
+
522
+
523
+ def compile_mask(
524
+ self,
525
+ input_ids: torch.Tensor,
526
+ attention_mask: Optional[torch.Tensor] = None,
527
+ past_key_values: Optional[ValidCache] = None,
528
+ pad_token_id=65509,
529
+ ) -> Optional[BlockMask]:
530
+ batch_size, seq_len = input_ids.shape[0], input_ids.shape[1]
531
+
532
+ # If no padding and no attention mask, no need for a mask
533
+ if attention_mask is None and (input_ids == pad_token_id).sum() == 0:
534
+ return None
535
+
536
+ if past_key_values is not None and seq_len == 1:
537
+ return None
538
+
539
+ # Get total sequence length including cache
540
+ cache_len = past_key_values.get_seq_length() if past_key_values is not None else 0
541
+ kv_length = cache_len + seq_len
542
+
543
+ if attention_mask is None:
544
+
545
+ def mask_mod(b, h, q_idx, kv_idx):
546
+ return q_idx >= kv_idx & (input_ids[b, kv_idx] != pad_token_id)
547
+ else:
548
+
549
+ def mask_mod(b, h, q_idx, kv_idx):
550
+ return (q_idx >= kv_idx) & (input_ids[b, kv_idx] != pad_token_id) & attention_mask[b, q_idx, kv_idx]
551
+
552
+ kv_length = past_key_values.get_seq_length() if past_key_values is not None else seq_len
553
+ if kv_length == 0:
554
+ kv_length = seq_len # prefill
555
+ block_mask = create_block_mask(
556
+ mask_mod,
557
+ B=batch_size,
558
+ H=None,
559
+ Q_LEN=seq_len,
560
+ KV_LEN=kv_length,
561
+ device=input_ids.device,
562
+ )
563
+
564
+ # # Define mask_mod function
565
+ # def mask_mod(b, h, q_idx, kv_idx):
566
+ # # Always apply causal constraint
567
+ # is_causal = q_idx >= kv_idx
568
+
569
+ # # Handle cache vs current tokens
570
+ # is_cache = kv_idx < cache_len
571
+ # current_idx = kv_idx - cache_len
572
+
573
+ # # For cache: always valid; For current: check padding
574
+ # not_pad = input_ids[b, current_idx] != pad_token_id
575
+ # valid = is_cache | not_pad
576
+
577
+ # # Apply attention mask if provided
578
+ # if attention_mask is not None:
579
+ # q_idx_curr = q_idx - cache_len
580
+ # attn_valid = attention_mask[b, q_idx_curr, current_idx]
581
+ # valid = valid & (is_cache | attn_valid)
582
+
583
+ # return is_causal & valid
584
+
585
+ # def mask_mod(b, h, q_idx, kv_idx):
586
+ # is_causal = q_idx >= kv_idx
587
+ # is_current = (kv_idx >= cache_len) & (kv_idx < kv_length)
588
+ # current_idx = kv_idx - cache_len
589
+
590
+ # is_valid = (~is_current) | (
591
+ # (current_idx >= 0) & (current_idx < seq_len) & (input_ids != pad_token_id)[b, current_idx % seq_len]
592
+ # )
593
+
594
+ # return is_causal & is_valid
595
+
596
+ # # Define mask_mod function
597
+ # def mask_mod(b, h, q_idx, kv_idx):
598
+ # # Always apply causal constraint
599
+ # is_causal = q_idx >= kv_idx
600
+
601
+ # # Handle cache vs current tokens
602
+ # is_cache = kv_idx < cache_len
603
+ # current_idx = kv_idx - cache_len
604
+ # in_bounds = (current_idx >= 0) & (current_idx < seq_len)
605
+
606
+ # # For cache: always valid; For current: check padding
607
+ # not_pad = (input_ids[b, current_idx % seq_len] != pad_token_id) | ~in_bounds
608
+ # valid = is_cache | (not_pad & in_bounds)
609
+
610
+ # # Apply attention mask if provided
611
+ # if attention_mask is not None:
612
+ # q_idx_curr = q_idx - cache_len
613
+ # q_in_bounds = (q_idx_curr >= 0) & (q_idx_curr < seq_len)
614
+ # attn_valid = attention_mask[b, q_idx_curr % seq_len, current_idx % seq_len] | ~(in_bounds & q_in_bounds)
615
+ # valid = valid & (is_cache | attn_valid)
616
+
617
+ # return is_causal & valid
618
+
619
+ # Create block mask
620
+ block_mask = create_block_mask(
621
+ mask_mod,
622
+ B=batch_size,
623
+ H=None,
624
+ Q_LEN=seq_len,
625
+ KV_LEN=kv_length,
626
+ device=input_ids.device,
627
+ )
628
+
629
+ return block_mask
630
+
631
    def forward(
        self,
        input_ids: torch.Tensor,
        input_embeds: Optional[torch.Tensor] = None,
        input_states: Optional[torch.Tensor] = None,
        attention_mask: Optional[torch.Tensor] = None,  # binary mask of shape q x kv, True=valid position
        position_ids: Optional[torch.Tensor] = None,
        labels: Optional[torch.Tensor] = None,
        num_steps: Optional[torch.Tensor] = None,
        past_key_values: Optional[ValidCache] = None,
        # NOTE: mutable default dict is shared across calls; safe only because it is
        # read, never mutated, in this function.
        output_details: dict = {
            "return_logits": True,
            "return_latents": True,
            "return_head": False,
            "return_stats": False,
        },
        use_cache: bool = False,
        cache_position: Optional[torch.Tensor] = None,
        init_scale: float = 1.0,
        **kwargs,
    ) -> CausalLMOutputRecurrentLatents:
        """Full forward pass: prelude -> recurrent core (num_steps iterations) -> coda -> LM head.

        With `labels`, cross-entropy loss is computed (ignore_index=-100); otherwise
        loss and log_ppl are returned as 0.0 placeholders.
        """
        # Support multiple position formats:
        if position_ids is None and cache_position is None:
            position_ids = torch.arange(input_ids.shape[1], device=self.device).unsqueeze(0)
        elif cache_position is not None:
            position_ids = cache_position.unsqueeze(0)

        if input_embeds is None:
            input_embeds = self.transformer.wte(input_ids)  # type: ignore # types broken in 2.6+

        if self.emb_scale != 1:
            input_embeds = input_embeds * self.emb_scale  # type: ignore

        if use_cache and past_key_values is None:
            past_key_values = HuginnDynamicCache()

        # Mask construction is currently disabled; attention falls back to SDPA causal paths.
        prepared_attn_mask = None  # self.compile_mask(input_ids, attention_mask, past_key_values)
        block_idx = torch.tensor(-1, device=torch.device("cpu"), dtype=torch.long)  # count in tensors for compile

        freqs_cis = self.rotary_emb(input_embeds, position_ids)

        # Non-recurrent prelude
        for block in self.transformer.prelude:  # type: ignore # types broken in 2.6+
            block_idx += 1
            input_embeds = block(input_embeds, freqs_cis, block_idx, prepared_attn_mask, past_key_values)

        # Main recurrence
        x, num_steps_no_grad, num_steps_with_grad, xk, block_idx = self.iterate_forward(
            input_embeds,  # type: ignore # mystery typing error
            input_states,
            freqs_cis,
            block_idx,
            prepared_attn_mask,
            past_key_values,
            num_steps,
            init_scale,
        )
        latent_states = x.clone().detach()

        # Coda layers
        block_idx = torch.tensor(0, device=torch.device("cpu"), dtype=torch.long)  # use negative indices for head
        for block in self.transformer.coda:  # type: ignore # types broken in 2.6+
            block_idx -= 1
            x = block(x, freqs_cis, block_idx, prepared_attn_mask, past_key_values)
        x = self.transformer.ln_f(x)  # type: ignore # types broken in 2.6+

        # Prediction head, assuming labels really are labels and not equal to input_ids
        if labels is not None:
            logits = self.lm_head(x).float()
            loss = torch.nn.functional.cross_entropy(
                logits.view(-1, logits.shape[-1]), labels.view(-1), ignore_index=-100
            )
            # NOTE(review): despite the name, this is exp(loss), i.e. perplexity.
            log_ppl = loss.clone().detach().exp()
        else:
            logits = self.lm_head(x)#.float()
            loss, log_ppl = torch.as_tensor(0.0), torch.as_tensor(0.0)

        return CausalLMOutputRecurrentLatents(
            loss=loss,
            log_ppl=log_ppl,
            logits=logits if output_details["return_logits"] else None,
            past_key_values=past_key_values,
            hidden_states=x if output_details["return_head"] else None,
            latent_states=latent_states if output_details["return_latents"] else None,
            stats=self.get_stats(logits, x, latent_states, xk, input_embeds, num_steps_no_grad, num_steps_with_grad)
            if output_details["return_stats"]
            else None,
        )
719
+
720
    @torch._dynamo.disable(recursive=False)  # type: ignore
    def iterate_forward(
        self,
        input_embeds: torch.Tensor,
        input_states: torch.Tensor,
        freqs_cis,
        block_idx: torch.Tensor,
        mask: Optional[BlockMask],
        past_key_values: Optional[ValidCache] = None,
        num_steps: Optional[torch.Tensor] = None,
        init_scale: float = 1.0,
    ):
        """Run the recurrent core: `num_steps_no_grad` iterations under no_grad, then
        `num_steps_with_grad` iterations with gradients (truncated backprop through depth).

        Returns (x, num_steps_no_grad, num_steps_with_grad, xk, block_idx); `xk` is the
        detached state from the iteration before the final one.
        """
        x = xk = self.initialize_state(input_embeds, scale=init_scale) if input_states is None else input_states.clone()
        if num_steps is None:
            num_steps_no_grad, num_steps_with_grad = self.randomized_iteration_sampler()  # type: ignore
        elif hasattr(num_steps, "__len__") and len(num_steps) > 1:
            # Caller supplied an explicit (no_grad, with_grad) pair.
            num_steps_no_grad, num_steps_with_grad = num_steps
        else:
            num_steps_no_grad, num_steps_with_grad = num_steps, torch.tensor(0) if not x.is_meta else 0

        with torch.no_grad():
            # ultra annoying in ddp due to
            # https://discuss.pytorch.org/t/does-distributeddataparallel-work-with-torch-no-grad-and-find-unused-parameters-false/122594
            # for now running with find_unused_params=True enabled even though the graph structure is (technically) clear
            # and all parameters are always used
            for no_grad_step in range(num_steps_no_grad):
                xk = x
                x, block_idx = self.core_block_forward(
                    xk, input_embeds, freqs_cis, mask, past_key_values, block_idx, no_grad_step
                )

        for grad_step in range(num_steps_with_grad):
            xk = x
            x, block_idx = self.core_block_forward(
                xk, input_embeds, freqs_cis, mask, past_key_values, block_idx, num_steps_no_grad + grad_step
            )
        return x, num_steps_no_grad, num_steps_with_grad, xk.detach(), block_idx  # type: ignore # types broken in 2.6+
757
+
758
    def core_block_forward(
        self,
        x,
        input_embeds,
        freqs_cis,
        mask: Optional[BlockMask],
        past_key_values,
        block_idx: torch.Tensor,
        current_step: int | Tensor,
    ):
        """One recurrence iteration: optional test-time noise injection, re-inject the
        input embedding through the adapter, then run every core block once.

        `block_idx` is incremented in place per block so the KV cache keys stay unique.
        """
        x = self._maybe_inject_noise(x, current_step)
        # Concatenate latent state with the input injection and project back to n_embd.
        x = self.transformer.adapter(torch.cat([x, input_embeds.to(x.device)], dim=-1))  # type: ignore # types broken in 2.6+
        for block in self.transformer.core_block:  # type: ignore # types broken in 2.6+
            block_idx += 1
            x = block(x, freqs_cis, block_idx, mask, past_key_values)
        return x, block_idx
774
+
775
+ @torch.no_grad()
776
+ def iterate_one_step(
777
+ self,
778
+ input_embeds,
779
+ input_states,
780
+ position_ids: Optional[torch.Tensor] = None,
781
+ cache_position: Optional[torch.Tensor] = None,
782
+ block_idx: torch.Tensor = torch.tensor(0, dtype=torch.long),
783
+ attention_mask: Optional[BlockMask] = None,
784
+ past_key_values: Optional[ValidCache] = None,
785
+ current_step: int = 0,
786
+ ):
787
+ if position_ids is None and cache_position is None:
788
+ freqs_cis = self.freqs_cis[:, : input_embeds.shape[1]]
789
+ elif position_ids is not None:
790
+ freqs_cis = self.freqs_cis.index_select(1, position_ids.squeeze())
791
+ elif cache_position is not None:
792
+ freqs_cis = self.freqs_cis[:, cache_position]
793
+ x, block_idx = self.core_block_forward(
794
+ input_states,
795
+ input_embeds,
796
+ freqs_cis,
797
+ attention_mask,
798
+ past_key_values,
799
+ block_idx,
800
+ current_step=current_step,
801
+ )
802
+ return x, block_idx, current_step + 1
803
+
804
    def predict_from_latents(
        self,
        latents,
        attention_mask: Optional[BlockMask] = None,
        position_ids: Optional[torch.Tensor] = None,
        cache_position: Optional[torch.Tensor] = None,
        past_key_values: Optional[ValidCache] = None,
    ):
        """Decode logits from an externally-produced latent state: ln_f -> coda -> ln_f -> lm_head.

        Counterpart to `iterate_one_step` when driving the recurrence manually.
        NOTE(review): relies on a `self.freqs_cis` buffer not assigned in the visible
        __init__ — presumably registered elsewhere in the class.
        """
        # Select rotary frequencies for the requested positions.
        if position_ids is None and cache_position is None:
            freqs_cis = self.freqs_cis[:, : latents.shape[1]]
        elif position_ids is not None:
            freqs_cis = self.freqs_cis.index_select(1, position_ids.squeeze())
        elif cache_position is not None:
            freqs_cis = self.freqs_cis[:, cache_position]
        x = self.transformer.ln_f(latents)  # type: ignore # types broken in 2.6+
        # Coda layers
        block_idx = torch.tensor(0, device=torch.device("cpu"), dtype=torch.long)  # use negative indices for head
        for block in self.transformer.coda:  # type: ignore # types broken in 2.6+
            block_idx -= 1
            x = block(x, freqs_cis, block_idx, attention_mask, past_key_values)
        x = self.transformer.ln_f(x)  # type: ignore # types broken in 2.6+

        logits = self.lm_head(x).float()

        return CausalLMOutputRecurrentLatents(
            loss=torch.as_tensor(0.0),
            log_ppl=torch.as_tensor(0.0),
            logits=logits,
            past_key_values=past_key_values,
            latent_states=x,
        )
835
+
836
    def embed_inputs(
        self,
        input_ids: torch.Tensor,
        attention_mask: Optional[torch.Tensor] = None,
        position_ids: Optional[torch.Tensor] = None,
        past_key_values: Optional[ValidCache] = None,
        use_cache: bool = False,
        cache_position: Optional[torch.Tensor] = None,
        **kwargs,
    ) -> tuple[torch.Tensor, torch.Tensor]:
        """Embed tokens and run only the non-recurrent prelude.

        Returns (prelude_output, last_block_idx); used e.g. to warm-start
        continuous-compute generation.
        NOTE(review): relies on a `self.freqs_cis` buffer not assigned in the visible
        __init__ — presumably registered elsewhere in the class.
        """
        # Support multiple position formats:
        if position_ids is None and cache_position is None:
            freqs_cis = self.freqs_cis[:, : input_ids.shape[1]]
        elif position_ids is not None:
            freqs_cis = self.freqs_cis.index_select(1, position_ids.squeeze())
        elif cache_position is not None:
            freqs_cis = self.freqs_cis[:, cache_position]

        input_embeds = self.transformer.wte(input_ids)  # type: ignore # types broken in 2.6+
        prepared_attn_mask = self.compile_mask(input_ids, attention_mask)

        if self.emb_scale != 1:
            input_embeds = input_embeds * self.emb_scale  # type: ignore

        if use_cache and past_key_values is None:
            past_key_values = HuginnDynamicCache()

        block_idx = torch.tensor(-1, device=torch.device("cpu"), dtype=torch.long)  # count in tensors for compile
        # Non-recurrent prelude
        for block in self.transformer.prelude:  # type: ignore # types broken in 2.6+
            block_idx += 1
            input_embeds = block(input_embeds, freqs_cis, block_idx, prepared_attn_mask, past_key_values)
        return input_embeds, block_idx
869
+
870
    @torch._dynamo.disable(recursive=False)  # type: ignore
    def randomized_iteration_sampler(self) -> tuple[torch.Tensor, torch.Tensor]:
        """Outputs are long tensors so that they can be passed through compiled functions.

        Training: total depth is sampled as Poisson(LogNormal(mu, 0.5)) + 1 with mean
        t + s, then split into a no-grad prefix `n` and a with-grad suffix `k` of at
        most `mean_backprop_depth` steps. Evaluation: (mean_recurrence, 0).
        """
        t = max(self.config.mean_recurrence - self.config.mean_backprop_depth, 0)
        s = self.config.mean_backprop_depth
        if torch.rand((1,)).is_meta:  # annoying clause to make meta-tensor-based flop counting work
            # these values are only the mean TFLOPs of the randomized sampler
            # Note that this clause also breaks the contract, and returns ints in meta tensor mode
            return t, s  # type: ignore
        if self.training:
            sigma = 0.5
            # Choose mu so the log-normal's mean equals the configured total depth t + s.
            mu = math.log(t + s) - (sigma**2 / 2)
            rate = torch.zeros((1,)).log_normal_(mean=mu, std=sigma)
            p = torch.poisson(torch.tensor([rate], dtype=torch.float)) + 1
            n = torch.clamp(p - s, min=0)
            k = torch.as_tensor(torch.minimum(torch.as_tensor(s), p))
        else:
            n, k = torch.as_tensor(self.config.mean_recurrence), torch.as_tensor(0)

        return n.to(dtype=torch.long), k.to(dtype=torch.long)
890
+
891
+ def initialize_state(self, input_embeds, scale: float = 1.0):
892
+ x = torch.randn_like(input_embeds)
893
+ std = self.config.init_values["std"] * scale
894
+ if std > 0:
895
+ torch.nn.init.trunc_normal_(x, mean=0.0, std=std, a=-3 * std, b=3 * std)
896
+ if self.emb_scale != 1:
897
+ x = x * self.emb_scale
898
+ else:
899
+ x.zero_()
900
+ return x
901
+
902
    def _maybe_inject_noise(self, x, current_step, renorm=False):
        """Optionally mix Gaussian noise into the latent state at test time.

        The base magnitude `n` scales with config.test_time_noise and the
        state-init std; the decay schedule is picked by config.test_time_noise_type.
        No-op when config.test_time_noise == 0.
        """
        if self.config.test_time_noise > 0:
            n = self.config.test_time_noise * self.config.init_values["std"] * self.emb_scale
            if self.config.test_time_noise_type == "geom":
                # Noise decays as 1 / step.
                step1 = torch.as_tensor(current_step + 1, device=x.device)  # need to cast for compile
                x = x * (1 - n / step1) + torch.randn_like(x) * n / step1
            elif self.config.test_time_noise_type == "sqrt":
                # Noise decays as 1 / sqrt(step).
                step1sqrt = torch.as_tensor(current_step + 1, device=x.device).sqrt()  # need to cast for compile
                x = x * (1 - n / step1sqrt) + torch.randn_like(x) * n / step1sqrt
            elif self.config.test_time_noise_type == "line":
                # Linear ramp-down over mean_recurrence steps, floored at n.
                noise = max(n, (self.config.mean_recurrence - current_step) / self.config.mean_recurrence)  # type: ignore
                x = x * (1 - noise) + torch.randn_like(x) * noise
            elif self.config.test_time_noise_type == "chi":
                # Uniformly random magnitude in [0, 2n).
                noise = 2 * torch.rand(1, device=x.device, dtype=x.dtype) * n
                x = x * (1 - noise) + torch.randn_like(x) * noise
            elif self.config.test_time_noise_type == "fixed":
                x = x * (1 - n) + torch.randn_like(x) * n
            else:
                raise ValueError()

            if renorm:
                # NOTE(review): the visible SandwichBlock defines only norm_1/norm_2;
                # `norm_4` must come from an expanded block variant — confirm before
                # enabling renorm.
                x = self.transformer.core_block[-1].norm_4(x)  # type: ignore moduledict types still broken in pytorch
        return x
925
+
926
+ def prepare_inputs_for_generation(
927
+ self,
928
+ input_ids: torch.Tensor,
929
+ past_key_values: Optional[Cache] = None,
930
+ attention_mask: Optional[torch.Tensor] = None,
931
+ inputs_embeds: Optional[torch.FloatTensor] = None,
932
+ cache_position: Optional[torch.Tensor] = None,
933
+ cache_lookup_strategy: str = "full",
934
+ **kwargs,
935
+ ):
936
+ model_inputs = {}
937
+ model_inputs["cache_position"] = cache_position
938
+ current_input_length = input_ids.shape[1]
939
+
940
+ if past_key_values is not None:
941
+ if not isinstance(past_key_values, (HuginnDynamicCache, HuginnStaticCache)):
942
+ assert past_key_values.get_seq_length() == 0 # only replace empty caches
943
+ # Need to use custom cache, detect and replace HF cache if generate injects it
944
+ if isinstance(past_key_values, StaticCache):
945
+ past_key_values = HuginnStaticCache(
946
+ max_length=getattr(self.generation_config, "max_length", self.config.block_size),
947
+ max_num_steps=4 + kwargs.get("num_steps", self.config.mean_recurrence) * 4,
948
+ num_heads=self.config.num_key_value_heads,
949
+ hidden_dim=self.config.n_embd // self.config.num_attention_heads,
950
+ dtype=torch.bfloat16,
951
+ device=input_ids.device,
952
+ lookup_strategy=cache_lookup_strategy,
953
+ )
954
+ else:
955
+ past_key_values = HuginnDynamicCache(lookup_strategy=cache_lookup_strategy)
956
+ model_inputs["past_key_values"] = past_key_values if kwargs["use_cache"] else None
957
+ input_ids = input_ids[:, cache_position] # type: ignore
958
+
959
+ model_inputs["input_ids"] = input_ids.clone(memory_format=torch.contiguous_format)
960
+ if cache_position is None:
961
+ position_ids = torch.arange(current_input_length)[None, :].to(input_ids.device)
962
+ model_inputs["position_ids"] = position_ids[:, -current_input_length:].clone(
963
+ memory_format=torch.contiguous_format
964
+ ) # some form of position_ids is a critical argument for the model to correctly apply rope!
965
+
966
+ # forward all other entries
967
+ for key, value in kwargs.items():
968
+ if key not in model_inputs:
969
+ model_inputs[key] = value
970
+ return model_inputs
971
+
972
+ @torch.no_grad()
973
+ def generate(self, *args, **kwargs):
974
+ """Dispatcher - use HF generate in all normal cases."""
975
+ self.generation_config = args[1] if len(args) > 1 else self.generation_config
976
+ if any(k in kwargs for k in ("criterion", "exit_threshold")):
977
+ # print("Dispatching to custom generate_adaptive function call")
978
+ return self.generate_with_adaptive_compute(*args, **kwargs)
979
+ elif "continuous_compute" in kwargs:
980
+ # print("Dispatching to custom generate_minimal function call")
981
+ return self.generate_minimal(*args, **kwargs)
982
+ else:
983
+ return super().generate(*args, **kwargs)
984
+
985
+ @torch.no_grad()
986
+ def _prep_generate_args(
987
+ self,
988
+ input_ids: torch.Tensor,
989
+ generation_config: Optional[GenerationConfig] = None, # type: ignore
990
+ cache_lookup_strategy: str = "full",
991
+ model_kwargs: dict = {},
992
+ ):
993
+ # Setup
994
+ if generation_config is None:
995
+ generation_config: GenerationConfig = self.generation_config # type: ignore
996
+ if "max_new_tokens" in model_kwargs:
997
+ max_new_tokens = model_kwargs["max_new_tokens"]
998
+ if "max_length" in model_kwargs:
999
+ max_new_tokens = min(max_new_tokens, model_kwargs["max_length"] - input_ids.shape[1])
1000
+ else:
1001
+ max_length = model_kwargs.get("max_length", generation_config.max_length)
1002
+ max_new_tokens = max_length - input_ids.shape[1]
1003
+
1004
+ if "cache_implementation" not in model_kwargs or model_kwargs["cache_implementation"] == "dynamic":
1005
+ model_kwargs["past_key_values"] = HuginnDynamicCache(lookup_strategy=cache_lookup_strategy)
1006
+ else:
1007
+ model_kwargs["past_key_values"] = HuginnStaticCache(
1008
+ max_length=max_length,
1009
+ max_num_steps=4 + model_kwargs.get("num_steps", self.config.mean_recurrence) * 4,
1010
+ num_heads=self.config.num_key_value_heads,
1011
+ hidden_dim=self.config.n_embd // self.config.num_attention_heads,
1012
+ batch_size=input_ids.shape[0],
1013
+ dtype=torch.bfloat16,
1014
+ device=input_ids.device,
1015
+ lookup_strategy=cache_lookup_strategy,
1016
+ )
1017
+ model_kwargs["use_cache"] = True
1018
+ model_kwargs = self._get_initial_cache_position(input_ids, model_kwargs)
1019
+ return model_kwargs, generation_config, max_new_tokens
1020
+
1021
+ @torch.no_grad()
1022
+ def generate_minimal(
1023
+ self,
1024
+ input_ids: torch.Tensor,
1025
+ generation_config: Optional[GenerationConfig] = None, # type: ignore
1026
+ tokenizer=None,
1027
+ streamer=None,
1028
+ continuous_compute=False, # warm-start state / continuous CoT
1029
+ init_scale: float = 1.0,
1030
+ cache_lookup_strategy: str = "full",
1031
+ **model_kwargs,
1032
+ ) -> Union[torch.Tensor, dict[str, Any]]:
1033
+ """Minimal single-sequence generation. Template for more complicated generate tasks"""
1034
+ model_kwargs, generation_config, max_new_tokens = self._prep_generate_args(
1035
+ input_ids, generation_config, cache_lookup_strategy
1036
+ )
1037
+ stop_tokens = self._get_stops(generation_config, tokenizer, model_kwargs).to(input_ids.device)
1038
+ unfinished_sequences = torch.ones(input_ids.shape[0], dtype=torch.long, device=input_ids.device)
1039
+
1040
+ # Set up continuous compute if enabled
1041
+ if continuous_compute:
1042
+ embedded_inputs, _ = self.embed_inputs(input_ids)
1043
+ model_kwargs["input_states"] = self.initialize_state(embedded_inputs, scale=init_scale)
1044
+
1045
+ # Generate tokens
1046
+ batch_size = input_ids.shape[0]
1047
+ for _ in range(max_new_tokens):
1048
+ # Forward pass
1049
+ model_inputs = self.prepare_inputs_for_generation(input_ids, **model_kwargs)
1050
+ outputs = self(**model_inputs, init_scale=init_scale)
1051
+
1052
+ # Get next token
1053
+ next_token_logits = outputs.logits[:, -1, :].to(copy=True, dtype=torch.float32, device=input_ids.device)
1054
+ next_token = self._sample_next_token(next_token_logits, generation_config)
1055
+
1056
+ # Append token to sequence
1057
+ input_ids = torch.cat([input_ids, next_token], dim=-1)
1058
+
1059
+ if streamer:
1060
+ streamer.put(next_token.cpu())
1061
+
1062
+ # Update model kwargs
1063
+ model_kwargs = self._update_model_kwargs_for_generation(outputs, model_kwargs)
1064
+ if continuous_compute:
1065
+ model_kwargs["input_states"] = outputs.latent_states[:, -1:, :]
1066
+
1067
+ if stop_tokens is not None:
1068
+ for i in range(batch_size):
1069
+ if unfinished_sequences[i] and next_token[i, 0].item() in stop_tokens:
1070
+ unfinished_sequences[i] = 0
1071
+ if "stopping_criteria" in model_kwargs:
1072
+ unfinished_sequences = unfinished_sequences & ~model_kwargs["stopping_criteria"](input_ids, None)
1073
+ if unfinished_sequences.max() == 0:
1074
+ break
1075
+
1076
+ if streamer:
1077
+ streamer.end()
1078
+
1079
+ if generation_config.return_dict_in_generate:
1080
+ return GenerateDecoderOnlyOutput(
1081
+ sequences=input_ids, # type: ignore
1082
+ scores=None,
1083
+ logits=None,
1084
+ attentions=None,
1085
+ hidden_states=None,
1086
+ past_key_values=model_kwargs.get("past_key_values"),
1087
+ )
1088
+ return input_ids
1089
+
1090
    @torch.no_grad()
    def generate_with_adaptive_compute(
        self,
        input_ids: torch.Tensor,
        generation_config: Optional[GenerationConfig] = None,  # type: ignore
        tokenizer=None,
        streamer=None,
        continuous_compute=False,  # warm-start state / continuous CoT
        criterion="none",  # off by default, turn on by choosing an exit criterion
        exit_threshold: Union[str, float, int] = "auto",
        init_scale: float = 1.0,
        cache_lookup_strategy: str = "full",
        **model_kwargs,
    ) -> Union[torch.Tensor, GenerateDecoderOnlyOutput]:
        """
        Generate tokens with adaptive compute. This is NOT the most efficient implementation.
        For batches, on each token, we iterate until the entire batch finishes.

        Supported `criterion` values: "entropy-diff", "latent-diff", "kl"/"minp-kl",
        "argmax-stability", "none". Each criterion produces a per-sequence exit value
        at every recurrence step; a sequence stops iterating once that value crosses
        `exit_threshold` ("auto" picks a per-criterion default). Per-token step counts
        and exit-value traces are returned in `scores` when
        `generation_config.return_dict_in_generate` is set.
        """
        model_kwargs, generation_config, max_new_tokens = self._prep_generate_args(
            input_ids, generation_config, cache_lookup_strategy, model_kwargs
        )
        # Upper bound on recurrence steps per token.
        max_steps = model_kwargs.get("num_steps", self.config.mean_recurrence)
        stop_tokens = self._get_stops(generation_config, tokenizer, model_kwargs).to(input_ids.device)
        # kwargs reused for every logits -> float32 copy below.
        logit_type = dict(copy=True, dtype=torch.float32, device=input_ids.device)
        batch_size = input_ids.shape[0]
        compute_steps = []

        # Set up continuous compute if enabled
        if continuous_compute:
            embedded_inputs, _ = self.embed_inputs(input_ids)
            model_kwargs["input_states"] = self.initialize_state(embedded_inputs, scale=init_scale)

        # Track which sequences have finished (using unfinished_sequences to match generate_minimal)
        unfinished_sequences = torch.ones(batch_size, dtype=torch.long, device=input_ids.device)

        # Generate tokens
        for _ in range(max_new_tokens):
            # Adaptive compute forward: embed once, then run the recurrence manually
            # step by step so each sequence can exit early.
            model_inputs = self.prepare_inputs_for_generation(input_ids, **model_kwargs)
            aux_inputs = {
                k: model_inputs[k] for k in ["cache_position", "past_key_values", "attention_mask"] if k in model_inputs
            }
            embedded_inputs, block_idx = self.embed_inputs(model_inputs["input_ids"], **aux_inputs)
            current_latents = (
                self.initialize_state(embedded_inputs, scale=init_scale)
                if not continuous_compute
                else model_kwargs["input_states"]
            )

            # Initialize criterion tracking for each sequence in batch
            exit_values_per_seq = [[] for _ in range(batch_size)]
            compute_steps_per_seq = [0] * batch_size
            exit_reached = torch.zeros(batch_size, dtype=torch.bool, device=input_ids.device)

            # Set up criterions based on selected strategy.
            # NOTE(review): `exit_threshold` (a parameter) is rebound here on every
            # token iteration; after the first token it is already numeric, so the
            # "auto" comparison only fires once. Works, but fragile — verify.
            if criterion == "entropy-diff":
                entropy = torch.ones(batch_size, device=input_ids.device) * 100.0
                exit_threshold = 1e-3 if exit_threshold == "auto" else float(exit_threshold)
            elif criterion == "latent-diff":
                exit_threshold = 0.03 if exit_threshold == "auto" else float(exit_threshold)
            elif "kl" in criterion:
                V = self.config.padded_vocab_size
                # Start from a uniform distribution so the first KL is well-defined.
                log_probs = ((1 / V) * torch.ones(batch_size, V, dtype=torch.float, device=input_ids.device)).log()
                if criterion == "minp-kl":
                    exit_threshold = 1e-6 if exit_threshold == "auto" else float(exit_threshold)
                else:
                    exit_threshold = 5e-4 if exit_threshold == "auto" else float(exit_threshold)
            elif criterion == "argmax-stability":
                stable_for_n_steps = torch.zeros(batch_size, dtype=torch.long, device=input_ids.device)
                current_argmax = torch.ones(batch_size, dtype=torch.long, device=input_ids.device) * -1
                exit_threshold = 5 if exit_threshold == "auto" else int(exit_threshold)
            elif criterion == "none":
                exit_threshold = 1.0 if exit_threshold == "auto" else float(exit_threshold)
            else:
                raise ValueError("Invalid adaptive compute strategy.")

            next_token_logits = None

            # Iterate through compute steps
            for compute_step in range(max_steps):
                prev_latents = current_latents.clone()
                current_latents, block_idx, _ = self.iterate_one_step(
                    embedded_inputs,
                    current_latents,
                    block_idx=block_idx,
                    **aux_inputs,
                    current_step=compute_step,
                )

                # NOTE(review): `_` here is the third value just unpacked from
                # iterate_one_step, which shadows the outer token-loop variable.
                # The comment below suggests the token index was intended — confirm
                # what iterate_one_step's third return actually is.
                if _ > 0:  # do not exit in prefill
                    # Check exit condition for each sequence in batch
                    if criterion == "entropy-diff":
                        prev_entropy = entropy
                        outputs = self.predict_from_latents(current_latents, **aux_inputs)
                        logits: torch.Tensor = outputs.logits  # type: ignore
                        probs = F.softmax(logits[:, -1, :], dim=-1)
                        entropy = -torch.sum(probs * torch.log(probs + 1e-10), dim=-1)
                        exit_values = (entropy - prev_entropy).abs()
                    elif criterion == "latent-diff":
                        # Cheap criterion: no logits needed, only latent movement.
                        norm_diff = (prev_latents - current_latents).norm(dim=-1) / current_latents.norm(dim=-1)
                        exit_values = norm_diff.mean(dim=-1)
                    elif "kl" in criterion:
                        outputs = self.predict_from_latents(current_latents, **aux_inputs)
                        logits: torch.Tensor = outputs.logits  # type: ignore
                        prev_log_probs = log_probs
                        if criterion == "minp-kl":
                            # Flatten the tail (< 10% of the max prob) to uniform mass
                            # before the KL, to ignore noise in unlikely tokens.
                            probs = F.softmax(logits[:, -1, :].float(), dim=-1)
                            max_probs = probs.max(dim=-1, keepdim=True)[0]
                            probs_mask = probs < (0.1 * max_probs)
                            masked_probs = probs.clone()
                            masked_probs[probs_mask] = 1 / V
                            probs = masked_probs / masked_probs.sum(dim=-1, keepdim=True)
                            log_probs = probs.log()
                        else:
                            log_probs = F.log_softmax(logits[:, -1, :].float(), dim=-1)
                        exit_values = F.kl_div(log_probs, prev_log_probs, reduction="none", log_target=True).sum(dim=-1)
                    elif criterion == "argmax-stability":
                        prev_argmax = current_argmax
                        outputs = self.predict_from_latents(current_latents, **aux_inputs)
                        logits: torch.Tensor = outputs.logits  # type: ignore
                        current_argmax = logits[:, -1, :].argmax(dim=-1)
                        # Count consecutive steps with an unchanged argmax token.
                        stable_for_n_steps = torch.where(
                            current_argmax == prev_argmax, stable_for_n_steps + 1, torch.zeros_like(stable_for_n_steps)
                        )
                        exit_values = stable_for_n_steps
                    elif criterion == "none":
                        # Always above threshold -> never exits early.
                        exit_values = torch.ones(batch_size, device=input_ids.device) * 2.0 * exit_threshold

                    # Record values and check exits for each sequence
                    for i in range(batch_size):
                        if not exit_reached[i] and unfinished_sequences[i].bool():
                            exit_values_per_seq[i].append(exit_values[i].item())

                    # Check for new exits, respecting unfinished_sequences
                    # (argmax-stability exits when the value rises ABOVE threshold;
                    # all other criteria exit when it falls below).
                    new_exits = (
                        exit_values < exit_threshold
                        if criterion != "argmax-stability"
                        else exit_values >= exit_threshold
                    )
                    new_exits = new_exits & ~exit_reached & unfinished_sequences.bool()

                    if new_exits.any():
                        exit_reached = exit_reached | new_exits
                        if criterion == "latent-diff":
                            # Normally we don't compute the output for latent-diff, but when there is an exit,
                            # we need to compute and save the output
                            outputs = self.predict_from_latents(current_latents, **aux_inputs)
                            logits: torch.Tensor = outputs.logits  # type: ignore
                        if next_token_logits is None:
                            next_token_logits = logits[:, -1, :].to(**logit_type)  # type: ignore
                        else:
                            # Only overwrite rows that exited at this step; earlier
                            # exits keep the logits frozen at their exit step.
                            for i in range(batch_size):
                                if new_exits[i]:
                                    next_token_logits[i] = logits[i, -1, :].to(**logit_type)  # type: ignore
                        for i in range(batch_size):
                            if new_exits[i]:
                                compute_steps_per_seq[i] = compute_step + 1

                    # If all sequences have exited or finished, break early
                    if (exit_reached | ~unfinished_sequences.bool()).all():
                        break
            # This else is if the for loop finished without breaking
            else:
                outputs = self.predict_from_latents(current_latents, **aux_inputs)

            # For sequences that didn't exit early, use the final logits
            if next_token_logits is None:
                next_token_logits = outputs.logits[:, -1, :].to(**logit_type)  # type: ignore
            else:
                for i in range(batch_size):
                    if not exit_reached[i] and unfinished_sequences[i].bool():
                        next_token_logits[i] = outputs.logits[i, -1, :].to(**logit_type)  # type: ignore
                        compute_steps_per_seq[i] = max_steps

            # Save latent states for continuous compute if enabled
            if continuous_compute:
                model_kwargs["input_states"] = current_latents[:, -1:, :]

            # Record compute steps for this token generation
            compute_steps.append([compute_steps_per_seq, exit_values_per_seq])

            # Sample or select next token based on generation config
            next_token = self._sample_next_token(next_token_logits, generation_config)

            # Append token to sequence
            input_ids = torch.cat([input_ids, next_token], dim=-1)

            if streamer:
                streamer.put(next_token.cpu())

            # Update model kwargs for next iteration
            model_kwargs = self._update_model_kwargs_for_generation(outputs, model_kwargs)

            # Check for stop tokens and update unfinished sequences
            for i in range(batch_size):
                if (
                    unfinished_sequences[i].bool()
                    and stop_tokens is not None
                    and next_token[i, 0].item() in stop_tokens
                ):
                    unfinished_sequences[i] = 0

            # Apply any custom stopping criteria
            if "stopping_criteria" in model_kwargs:
                unfinished_sequences = unfinished_sequences & ~model_kwargs["stopping_criteria"](input_ids, None)

            # Break if all sequences are finished
            if unfinished_sequences.max() == 0:
                break

        if streamer:
            streamer.end()

        if generation_config.return_dict_in_generate:
            return GenerateDecoderOnlyOutput(
                sequences=input_ids,  # type: ignore
                scores=compute_steps,  # per-token [steps_per_seq, exit_value_traces]
                logits=None,
                attentions=None,
                hidden_states=None,
                past_key_values=model_kwargs.get("past_key_values"),
            )
        return input_ids
1313
+
1314
+ def _get_stops(self, generation_config, tokenizer, model_kwargs):
1315
+ stop_tokens = {65504, 65505, 65508} # begin_text, end_text, end_turn
1316
+ if generation_config.eos_token_id is not None:
1317
+ stop_tokens.add(generation_config.eos_token_id)
1318
+ if "stopping_criteria" in model_kwargs and tokenizer is None:
1319
+ tokenizer = model_kwargs["stopping_criteria"][0].tokenizer
1320
+ if hasattr(generation_config, "stop_strings") and tokenizer and generation_config.stop_strings:
1321
+ for s in generation_config.stop_strings:
1322
+ token_id = tokenizer(s, add_special_tokens=False)["input_ids"][0]
1323
+ stop_tokens.add(token_id)
1324
+ return torch.tensor(list(stop_tokens))
1325
+
1326
+ def _sample_next_token(self, next_token_logits, generation_config):
1327
+ """Helper function to sample the next token."""
1328
+ if generation_config.do_sample:
1329
+ if generation_config.temperature:
1330
+ next_token_logits = next_token_logits.float() / generation_config.temperature
1331
+
1332
+ probs = F.softmax(next_token_logits, dim=-1)
1333
+
1334
+ # Apply top_k
1335
+ if generation_config.top_k:
1336
+ top_k_values, _ = torch.topk(probs, generation_config.top_k, dim=-1)
1337
+ min_values = top_k_values[:, -1].unsqueeze(-1).expand_as(probs)
1338
+ probs = torch.where(probs < min_values, torch.zeros_like(probs), probs)
1339
+
1340
+ # Apply top_p (nucleus sampling)
1341
+ if generation_config.top_p:
1342
+ sorted_probs, sorted_indices = torch.sort(probs, descending=True, dim=-1)
1343
+ cumulative_probs = torch.cumsum(sorted_probs, dim=-1)
1344
+
1345
+ # Create mask for probs to keep
1346
+ remove_indices = cumulative_probs > generation_config.top_p
1347
+ remove_indices[:, 0] = False # Keep at least the top probability
1348
+
1349
+ # Convert sorted indices mask back to original indices mask
1350
+ mask = torch.zeros_like(probs, dtype=torch.bool)
1351
+ for i in range(probs.shape[0]):
1352
+ mask[i, sorted_indices[i, remove_indices[i]]] = True
1353
+
1354
+ probs = torch.where(mask, torch.zeros_like(probs), probs)
1355
+
1356
+ # Apply min_p
1357
+ if generation_config.min_p:
1358
+ max_probs = probs.max(dim=-1, keepdim=True)[0]
1359
+ min_p_threshold = generation_config.min_p * max_probs
1360
+ probs = torch.where(probs < min_p_threshold, torch.zeros_like(probs), probs)
1361
+
1362
+ # Renormalize probabilities
1363
+ probs = probs / probs.sum(dim=-1, keepdim=True).clamp(min=1e-10)
1364
+
1365
+ # Sample from the distribution
1366
+ return torch.multinomial(probs, num_samples=1)
1367
+ else:
1368
+ return torch.argmax(next_token_logits, dim=-1, keepdim=True)
1369
+
1370
    @torch.no_grad()
    def generate_speculative(
        self,
        input_ids: torch.Tensor,
        generation_config: Optional[GenerationConfig] = None,  # type: ignore
        tokenizer=None,
        streamer=None,
        continuous_compute=False,  # warm-start state / continuous CoT
        init_scale: float = 1.0,
        cache_lookup_strategy: str = "full",
        draft_steps=32,
        lookahead_for_draft=8,
        verification_threshold=1,
        num_steps: int = 32,  # intercept deliberately
        **model_kwargs,
    ) -> Union[torch.Tensor, dict[str, Any]]:
        """Batched speculative decoding with per-sequence acceptance.

        The model drafts `lookahead_for_draft` tokens with a cheap recurrence
        depth (`draft_steps`), then re-scores them in one pass at full depth
        (`num_steps`). With `verification_threshold >= 1` a drafted token is
        accepted only on exact argmax match; below 1, it is accepted when its
        verified probability is within that fraction of the row maximum.
        """
        assert lookahead_for_draft > 0
        # Matches config.pad_token_id (65509); used to right-pad ragged acceptances.
        pad_id = 65509
        model_kwargs, generation_config, max_new_tokens = self._prep_generate_args(
            input_ids, generation_config, cache_lookup_strategy, model_kwargs
        )
        stop_tokens = self._get_stops(generation_config, tokenizer, model_kwargs).to(input_ids.device)
        unfinished_sequences = torch.ones(input_ids.shape[0], dtype=torch.long, device=input_ids.device)

        # Set up continuous compute if enabled
        if continuous_compute:
            embedded_inputs, _ = self.embed_inputs(input_ids)
            model_kwargs["input_states"] = self.initialize_state(embedded_inputs, scale=init_scale)

        tokens_generated = 0
        # Prefill cache with full num_steps (only when starting from an empty cache).
        if model_kwargs["past_key_values"].get_seq_length() == 0:
            model_inputs = self.prepare_inputs_for_generation(input_ids, **model_kwargs)
            outputs = self(**model_inputs, num_steps=num_steps, init_scale=init_scale)
            next_token = self._sample_next_token(
                outputs.logits[:, -1, :].to(copy=True, dtype=torch.float32), generation_config
            )
            input_ids = torch.cat([input_ids, next_token], dim=-1)
            tokens_generated += 1
            if streamer:
                streamer.put(next_token.cpu())
            model_kwargs["cache_position"] = torch.as_tensor(
                [model_inputs["past_key_values"].get_seq_length()], device=input_ids.device
            )
            if continuous_compute:
                model_kwargs["input_states"] = outputs.latent_states[:, -1:, :]

        # Generate tokens
        batch_size, prefix_seq_len = input_ids.shape[0], input_ids.shape[1]
        accepted_tokens = []  # per-round list of accepted-token counts per sequence

        while tokens_generated < max_new_tokens:
            ### Run the next draft ####
            drafted_inputs = input_ids.clone()
            current_len = input_ids.shape[1]

            # Draft lookahead tokens autoregressively at reduced depth.
            for _ in range(lookahead_for_draft):
                model_inputs = self.prepare_inputs_for_generation(drafted_inputs, **model_kwargs)
                outputs = self(**model_inputs, num_steps=draft_steps, init_scale=init_scale)
                next_token_logits = outputs.logits[:, -1, :].to(copy=True, dtype=torch.float32)
                next_token = self._sample_next_token(next_token_logits, generation_config)
                drafted_inputs = torch.cat([drafted_inputs, next_token], dim=-1)
                model_kwargs["cache_position"] = model_kwargs["cache_position"][-1:] + 1
                if continuous_compute:
                    model_kwargs["input_states"] = outputs.latent_states[:, -1:, :]

            # Drop the shallow draft entries; the verification pass rewrites them.
            model_kwargs["past_key_values"].clear_last_k_entries(lookahead_for_draft)

            ## Verify drafted tokens ###
            # One full-depth pass over the drafted window (positions are re-scored
            # starting one before the draft so each drafted token gets a prediction).
            model_kwargs["cache_position"] = torch.arange(
                current_len - 1, current_len + lookahead_for_draft - 1, device=input_ids.device
            )
            model_inputs = self.prepare_inputs_for_generation(drafted_inputs, **model_kwargs)
            outputs = self(**model_inputs, num_steps=num_steps, init_scale=init_scale)
            verified_next_token_preds = outputs.logits.argmax(dim=-1)

            if verification_threshold >= 1:
                # Strict mode: accept only exact argmax agreement.
                mismatched_tokens = (
                    verified_next_token_preds[:, -lookahead_for_draft:] != drafted_inputs[:, current_len:]
                )
                # torch.max over bool gives (any_mismatch, index of first mismatch)
                not_all_matched, first_mismatch = torch.max(mismatched_tokens, dim=1)
            else:
                # Relaxed mode: accept a drafted token whose verified probability is
                # within `verification_threshold` of the row's max probability.
                verified_logits = outputs.logits[:, -lookahead_for_draft:, :]
                verified_probs = F.softmax(verified_logits, dim=-1)
                drafted_token_probs = torch.gather(
                    verified_probs, -1, drafted_inputs[:, current_len:].unsqueeze(-1)
                ).squeeze(-1)
                max_probs = verified_probs.max(dim=-1)[0]
                verification_passed = drafted_token_probs >= verification_threshold * max_probs
                not_all_matched, first_mismatch = torch.max(~verification_passed, dim=1)

            # Per-sequence acceptance handling
            acceptance_lengths = torch.where(not_all_matched, first_mismatch, lookahead_for_draft)

            # Build next_tokens for each sequence
            next_tokens_batch = []
            for i in range(batch_size):
                seq_acceptance = acceptance_lengths[i].item()
                if not_all_matched[i] and seq_acceptance < lookahead_for_draft:
                    # Accept up to mismatch + sample final token from verified logits
                    accepted_part = drafted_inputs[i : i + 1, current_len : current_len + seq_acceptance]
                    final_token_logits = outputs.logits[i : i + 1, seq_acceptance, :].to(copy=True, dtype=torch.float32)
                    final_token = self._sample_next_token(final_token_logits, generation_config)
                    seq_tokens = torch.cat([accepted_part, final_token], dim=-1) if seq_acceptance > 0 else final_token
                else:
                    # Accept all drafted tokens
                    seq_tokens = drafted_inputs[i : i + 1, current_len : current_len + seq_acceptance]
                next_tokens_batch.append(seq_tokens)

            # Clean up KV cache - only if any sequence had mismatches.
            # NOTE(review): the cache is trimmed by a single batch-wide amount
            # (based on the earliest mismatch); per-sequence cache validity for
            # batch_size > 1 should be verified against the cache implementation.
            if not_all_matched.any():
                min_first_mismatch = first_mismatch.min().item()
                model_inputs["past_key_values"].clear_last_k_entries(lookahead_for_draft - min_first_mismatch - 1)

            # Concatenate accepted tokens to input_ids, right-padding ragged rows.
            batch_accepted_counts = [tokens.shape[1] for tokens in next_tokens_batch]
            max_len = max(batch_accepted_counts)
            padded_tokens = [
                torch.cat(
                    [
                        tokens,
                        pad_id * torch.ones((1, max_len - tokens.shape[1]), dtype=tokens.dtype, device=tokens.device),
                    ],
                    dim=-1,
                )
                if tokens.shape[1] < max_len
                else tokens
                for tokens in next_tokens_batch
            ]
            next_tokens = torch.cat(padded_tokens, dim=0)
            input_ids = torch.cat([input_ids, next_tokens], dim=-1)

            accepted_tokens.append(batch_accepted_counts)
            # Counts progress by the fastest sequence in the batch.
            tokens_generated += max(batch_accepted_counts)

            if streamer:
                # Only the first sequence is streamed.
                streamer.put(next_tokens_batch[0].cpu())

            model_kwargs["cache_position"] = torch.as_tensor(
                [model_inputs["past_key_values"].get_seq_length()], device=input_ids.device
            )
            if continuous_compute:
                model_kwargs["input_states"] = outputs.latent_states[:, -1:, :]

            # Check stopping conditions
            if stop_tokens is not None:
                for i in range(batch_size):
                    if unfinished_sequences[i] and torch.isin(next_tokens_batch[i], stop_tokens).any():
                        unfinished_sequences[i] = 0
            if "stopping_criteria" in model_kwargs:
                unfinished_sequences = unfinished_sequences & ~model_kwargs["stopping_criteria"](input_ids, None)
            if unfinished_sequences.max() == 0:
                break

        if streamer:
            streamer.end()

        # Cut off extraneous parts of the sequence per batch element:
        # everything after the first stop token becomes padding.
        if stop_tokens is not None:
            for i in range(batch_size):
                stop_positions = torch.isin(input_ids[i, prefix_seq_len:], stop_tokens).nonzero()
                if len(stop_positions) > 0:
                    input_ids[i, prefix_seq_len + stop_positions[0].item() + 1 :] = pad_id
        # Trim tensor to remove columns that are pad_id across all sequences
        non_pad_mask = input_ids != pad_id
        last_real_token = non_pad_mask.any(dim=0).nonzero()
        if len(last_real_token) > 0:
            input_ids = input_ids[:, : last_real_token[-1].item() + 1]

        if generation_config.return_dict_in_generate:
            return GenerateDecoderOnlyOutput(
                sequences=input_ids,  # type: ignore
                scores=accepted_tokens,  # per-round accepted-token counts
                logits=None,
                attentions=None,
                hidden_states=None,
                past_key_values=model_kwargs.get("past_key_values"),
            )
        return input_ids
1550
+
1551
+ def get_stats(self, logits, x, latent_states, xk, input_embeds, num_steps_no_grad, num_steps_with_grad):
1552
+ probs = torch.softmax(logits.float(), dim=-1)
1553
+ prob_entropy = torch.where(probs > 0, -probs * probs.log(), 0).sum(dim=-1)
1554
+ residual_diff = (x - latent_states).norm(dim=-1)
1555
+ rel_residual = residual_diff / latent_states.norm(dim=-1)
1556
+ stats = {
1557
+ "entropy": prob_entropy,
1558
+ "residual_diff": residual_diff,
1559
+ "rel_residual": rel_residual,
1560
+ "num_steps_no_grad": num_steps_no_grad,
1561
+ "num_steps_with_grad": num_steps_with_grad,
1562
+ }
1563
+ return stats
1564
+
1565
+
1566
#################################### HF registration ############################################################

from transformers import AutoConfig, AutoModel, AutoModelForCausalLM

# Register for the auto-class mechanism so that checkpoints loaded with
# trust_remote_code=True resolve to these implementations via the config's
# auto_map entries.
RavenConfig.register_for_auto_class()

RavenForCausalLM.register_for_auto_class("AutoModel")
RavenForCausalLM.register_for_auto_class("AutoModelForCausalLM")

# Also register "huginn_raven" (the config's model_type) in the in-process
# auto-class mappings, so AutoConfig/AutoModel* resolve it when this module is
# imported directly. Possibly redundant with the registrations above.
AutoConfig.register("huginn_raven", RavenConfig)
AutoModel.register(RavenConfig, RavenForCausalLM)
AutoModelForCausalLM.register(RavenConfig, RavenForCausalLM)