Sharjeelbaig committed on
Commit 05657af · verified · 1 Parent(s): 89655b1

Add transformers remote-code support for NeuroThinker pipeline loading

config.json CHANGED
@@ -15,5 +15,17 @@
    "pad_token_id": 50256,
    "bos_token_id": 50256,
    "eos_token_id": 50260,
-   "model_type": "neurothinker"
- }
+   "model_type": "neurothinker",
+   "architectures": [
+     "NeuroThinkerForCausalLM"
+   ],
+   "auto_map": {
+     "AutoConfig": "configuration_neurothinker.NeuroThinkerConfig",
+     "AutoModelForCausalLM": "modeling_neurothinker.NeuroThinkerForCausalLM"
+   },
+   "hidden_size": 384,
+   "num_hidden_layers": 6,
+   "num_attention_heads": 6,
+   "max_position_embeddings": 256,
+   "use_cache": false
+ }
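The added auto_map block is what makes remote-code loading work: when a caller passes trust_remote_code=True, the Auto classes import NeuroThinkerConfig and NeuroThinkerForCausalLM from the two Python files added in this commit instead of looking for a built-in architecture. A minimal sketch of that loading path, assuming a hypothetical repository id Sharjeelbaig/NeuroThinker (the actual repo id is not part of this diff):

# Sketch only; repo_id is an assumed placeholder.
from transformers import AutoConfig, AutoModelForCausalLM

repo_id = "Sharjeelbaig/NeuroThinker"

# trust_remote_code=True lets transformers fetch and execute the
# configuration_neurothinker.py / modeling_neurothinker.py files named in auto_map.
config = AutoConfig.from_pretrained(repo_id, trust_remote_code=True)
model = AutoModelForCausalLM.from_pretrained(repo_id, trust_remote_code=True)
print(type(config).__name__, type(model).__name__)
# expected: NeuroThinkerConfig NeuroThinkerForCausalLM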
configuration_neurothinker.py ADDED
@@ -0,0 +1,53 @@
+ from transformers import PretrainedConfig
+
+
+ class NeuroThinkerConfig(PretrainedConfig):
+     model_type = "neurothinker"
+
+     def __init__(
+         self,
+         vocab_size=50261,
+         d_model=384,
+         n_layers=6,
+         n_heads=6,
+         d_head=64,
+         d_ff=720,
+         d_memory=192,
+         max_seq_len=256,
+         dropout=0.1,
+         rope_theta=10000.0,
+         memory_decay_init=0.99,
+         surprise_threshold=0.1,
+         rms_norm_eps=1e-6,
+         pad_token_id=50256,
+         bos_token_id=50256,
+         eos_token_id=50260,
+         use_cache=False,
+         **kwargs,
+     ):
+         self.vocab_size = vocab_size
+         self.d_model = d_model
+         self.n_layers = n_layers
+         self.n_heads = n_heads
+         self.d_head = d_head
+         self.d_ff = d_ff
+         self.d_memory = d_memory
+         self.max_seq_len = max_seq_len
+         self.dropout = dropout
+         self.rope_theta = rope_theta
+         self.memory_decay_init = memory_decay_init
+         self.surprise_threshold = surprise_threshold
+         self.rms_norm_eps = rms_norm_eps
+         # Common Transformer config aliases expected by generation utilities.
+         self.hidden_size = d_model
+         self.num_hidden_layers = n_layers
+         self.num_attention_heads = n_heads
+         self.max_position_embeddings = max_seq_len
+         self.use_cache = use_cache
+
+         super().__init__(
+             pad_token_id=pad_token_id,
+             bos_token_id=bos_token_id,
+             eos_token_id=eos_token_id,
+             **kwargs,
+         )
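Beyond the NeuroThinker-specific fields, __init__ mirrors d_model, n_layers, n_heads, and max_seq_len onto the standard hidden_size / num_hidden_layers / num_attention_heads / max_position_embeddings names that transformers generation utilities read. A small sketch of that aliasing, assuming the file is importable as a top-level module from a local checkout:

# Sketch only; assumes configuration_neurothinker.py is on the Python path.
from configuration_neurothinker import NeuroThinkerConfig

cfg = NeuroThinkerConfig()
assert cfg.hidden_size == cfg.d_model == 384
assert cfg.num_hidden_layers == cfg.n_layers == 6
assert cfg.max_position_embeddings == cfg.max_seq_len == 256

# Overrides propagate to the aliases because they are derived inside __init__.
small = NeuroThinkerConfig(n_layers=2, max_seq_len=128)
print(small.num_hidden_layers, small.max_position_embeddings)  # 2 128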
modeling_neurothinker.py ADDED
@@ -0,0 +1,239 @@
+ import math
+ import torch
+ import torch.nn as nn
+ import torch.nn.functional as F
+
+ from transformers import PreTrainedModel
+ from transformers.modeling_outputs import CausalLMOutput
+
+ from .configuration_neurothinker import NeuroThinkerConfig
+
+
+ class RMSNorm(nn.Module):
+     def __init__(self, d_model: int, eps: float = 1e-6):
+         super().__init__()
+         self.weight = nn.Parameter(torch.ones(d_model))
+         self.eps = eps
+
+     def forward(self, x: torch.Tensor) -> torch.Tensor:
+         rms = torch.sqrt(torch.mean(x * x, dim=-1, keepdim=True) + self.eps)
+         return x / rms * self.weight
+
+
+ class SwiGLUFFN(nn.Module):
+     def __init__(self, d_model: int, d_ff: int, dropout: float = 0.1):
+         super().__init__()
+         self.w_gate = nn.Linear(d_model, d_ff, bias=False)
+         self.w_up = nn.Linear(d_model, d_ff, bias=False)
+         self.w_down = nn.Linear(d_ff, d_model, bias=False)
+         self.dropout = nn.Dropout(dropout)
+
+     def forward(self, x: torch.Tensor) -> torch.Tensor:
+         gate = F.silu(self.w_gate(x))
+         up = self.w_up(x)
+         return self.dropout(self.w_down(gate * up))
+
+
+ def precompute_rope_freqs(d_head: int, max_seq_len: int, theta: float = 10000.0, device=None):
+     freqs = 1.0 / (theta ** (torch.arange(0, d_head, 2, device=device).float() / d_head))
+     t = torch.arange(max_seq_len, device=device).float()
+     freqs = torch.outer(t, freqs)
+     return torch.polar(torch.ones_like(freqs), freqs)
+
+
+ def apply_rope(x: torch.Tensor, freqs: torch.Tensor) -> torch.Tensor:
+     x_pairs = x.float().reshape(*x.shape[:-1], -1, 2)
+     x_complex = torch.view_as_complex(x_pairs)
+     freqs = freqs.unsqueeze(0).unsqueeze(0)
+     x_rotated = x_complex * freqs[:, :, : x_complex.shape[2], :]
+     x_out = torch.view_as_real(x_rotated).reshape(*x.shape)
+     return x_out.type_as(x)
+
+
+ class RotaryMultiHeadAttention(nn.Module):
+     def __init__(self, d_model: int, n_heads: int, d_head: int, max_seq_len: int, dropout: float, rope_theta: float):
+         super().__init__()
+         self.n_heads = n_heads
+         self.d_head = d_head
+         self.scale = d_head ** -0.5
+
+         self.w_q = nn.Linear(d_model, n_heads * d_head, bias=False)
+         self.w_k = nn.Linear(d_model, n_heads * d_head, bias=False)
+         self.w_v = nn.Linear(d_model, n_heads * d_head, bias=False)
+         self.w_o = nn.Linear(n_heads * d_head, d_model, bias=False)
+         self.attn_dropout = nn.Dropout(dropout)
+         self.resid_dropout = nn.Dropout(dropout)
+
+         self.register_buffer(
+             "rope_freqs",
+             precompute_rope_freqs(d_head, max_seq_len, rope_theta),
+             persistent=False,
+         )
+
+     def forward(self, x: torch.Tensor, mask: torch.Tensor = None) -> torch.Tensor:
+         bsz, seq_len, _ = x.shape
+
+         q = self.w_q(x).view(bsz, seq_len, self.n_heads, self.d_head).transpose(1, 2)
+         k = self.w_k(x).view(bsz, seq_len, self.n_heads, self.d_head).transpose(1, 2)
+         v = self.w_v(x).view(bsz, seq_len, self.n_heads, self.d_head).transpose(1, 2)
+
+         q = apply_rope(q, self.rope_freqs[:seq_len].to(x.device))
+         k = apply_rope(k, self.rope_freqs[:seq_len].to(x.device))
+
+         attn = (q @ k.transpose(-2, -1)) * self.scale
+         if mask is not None:
+             attn = attn.masked_fill(mask == 0, float("-inf"))
+
+         attn = F.softmax(attn, dim=-1)
+         attn = self.attn_dropout(attn)
+
+         out = (attn @ v).transpose(1, 2).contiguous().view(bsz, seq_len, -1)
+         return self.resid_dropout(self.w_o(out))
+
+
+ class TitansMemoryModule(nn.Module):
+     def __init__(self, d_model: int, d_memory: int, decay_init: float = 0.99, dropout: float = 0.1):
+         super().__init__()
+         self.memory_net = nn.Sequential(
+             nn.Linear(d_model, d_memory, bias=False),
+             nn.SiLU(),
+             nn.Linear(d_memory, d_model, bias=False),
+         )
+         self.surprise_gate = nn.Sequential(
+             nn.Linear(d_model, d_model, bias=False),
+             nn.Sigmoid(),
+         )
+         self.forget_bias = nn.Parameter(torch.full((d_model,), decay_init))
+         self.momentum = nn.Parameter(torch.tensor(0.9))
+         self.out_proj = nn.Linear(d_model, d_model, bias=False)
+         self.dropout = nn.Dropout(dropout)
+         self.norm = nn.LayerNorm(d_model)
+         self.register_buffer("surprise_ema", torch.zeros(1))
+
+     def forward(self, x: torch.Tensor) -> torch.Tensor:
+         memory_out = self.memory_net(x)
+         surprise_signal = torch.norm(x - memory_out, dim=-1, keepdim=True)
+         surprise_signal = surprise_signal / (surprise_signal.mean() + 1e-8)
+
+         momentum = torch.sigmoid(self.momentum)
+         smoothed = momentum * self.surprise_ema + (1 - momentum) * surprise_signal.mean()
+         self.surprise_ema = smoothed.detach()
+
+         gate = self.surprise_gate(x)
+         gate = gate * torch.clamp(surprise_signal, 0, 2)
+
+         forget = torch.sigmoid(self.forget_bias).unsqueeze(0).unsqueeze(0)
+         updated = forget * memory_out + gate * x
+
+         out = self.out_proj(updated)
+         out = self.dropout(out)
+         return self.norm(out + x)
+
+
+ class NeuroThinkerBlock(nn.Module):
+     def __init__(self, config: NeuroThinkerConfig):
+         super().__init__()
+         self.attn_norm = RMSNorm(config.d_model, config.rms_norm_eps)
+         self.attn = RotaryMultiHeadAttention(
+             d_model=config.d_model,
+             n_heads=config.n_heads,
+             d_head=config.d_head,
+             max_seq_len=config.max_seq_len,
+             dropout=config.dropout,
+             rope_theta=config.rope_theta,
+         )
+         self.memory_norm = RMSNorm(config.d_model, config.rms_norm_eps)
+         self.memory = TitansMemoryModule(
+             d_model=config.d_model,
+             d_memory=config.d_memory,
+             decay_init=config.memory_decay_init,
+             dropout=config.dropout,
+         )
+         self.ffn_norm = RMSNorm(config.d_model, config.rms_norm_eps)
+         self.ffn = SwiGLUFFN(config.d_model, config.d_ff, config.dropout)
+
+     def forward(self, x: torch.Tensor, mask: torch.Tensor = None) -> torch.Tensor:
+         x = x + self.attn(self.attn_norm(x), mask=mask)
+         x = self.memory(self.memory_norm(x))
+         x = x + self.ffn(self.ffn_norm(x))
+         return x
+
+
+ class NeuroThinkerForCausalLM(PreTrainedModel):
+     config_class = NeuroThinkerConfig
+     base_model_prefix = "neurothinker"
+     main_input_name = "input_ids"
+     _keys_to_ignore_on_load_missing = [r"lm_head.weight"]
+
+     def __init__(self, config: NeuroThinkerConfig):
+         super().__init__(config)
+         self.token_emb = nn.Embedding(config.vocab_size, config.d_model)
+         self.blocks = nn.ModuleList([NeuroThinkerBlock(config) for _ in range(config.n_layers)])
+         self.final_norm = RMSNorm(config.d_model, config.rms_norm_eps)
+         self.lm_head = nn.Linear(config.d_model, config.vocab_size, bias=False)
+         self.lm_head.weight = self.token_emb.weight
+
+         self.post_init()
+
+     def _init_weights(self, module):
+         if isinstance(module, nn.Linear):
+             nn.init.normal_(module.weight, mean=0.0, std=0.02)
+             if module.bias is not None:
+                 nn.init.zeros_(module.bias)
+         elif isinstance(module, nn.Embedding):
+             nn.init.normal_(module.weight, mean=0.0, std=0.02)
+
+     def get_input_embeddings(self):
+         return self.token_emb
+
+     def set_input_embeddings(self, new_embeddings):
+         self.token_emb = new_embeddings
+         self.lm_head.weight = self.token_emb.weight
+
+     def get_output_embeddings(self):
+         return self.lm_head
+
+     def set_output_embeddings(self, new_embeddings):
+         self.lm_head = new_embeddings
+
+     def _make_causal_mask(self, seq_len: int, device: torch.device) -> torch.Tensor:
+         mask = torch.tril(torch.ones(seq_len, seq_len, device=device))
+         return mask.unsqueeze(0).unsqueeze(0)
+
+     def prepare_inputs_for_generation(self, input_ids, **kwargs):
+         return {"input_ids": input_ids}
+
+     def forward(self, input_ids=None, attention_mask=None, labels=None, **kwargs):
+         if input_ids is None:
+             raise ValueError("input_ids is required")
+
+         bsz, seq_len = input_ids.shape
+         if seq_len > self.config.max_seq_len:
+             input_ids = input_ids[:, -self.config.max_seq_len :]
+             if labels is not None:
+                 labels = labels[:, -self.config.max_seq_len :]
+             seq_len = input_ids.shape[1]
+
+         x = self.token_emb(input_ids)
+         mask = self._make_causal_mask(seq_len, x.device)
+
+         for block in self.blocks:
+             x = block(x, mask=mask)
+
+         x = self.final_norm(x)
+         logits = self.lm_head(x)
+         # Guard against numeric instability during sampling on small custom checkpoints.
+         logits = torch.nan_to_num(logits, nan=0.0, posinf=1e4, neginf=-1e4)
+         logits = torch.clamp(logits, min=-80.0, max=80.0)
+
+         loss = None
+         if labels is not None:
+             shift_logits = logits[:, :-1, :].contiguous()
+             shift_labels = labels[:, 1:].contiguous()
+             loss = F.cross_entropy(
+                 shift_logits.view(-1, self.config.vocab_size),
+                 shift_labels.view(-1),
+                 ignore_index=-100,
+             )
+
+         return CausalLMOutput(loss=loss, logits=logits)
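With the auto_map entries in config.json pointing at these two modules, the model can be driven through the standard text-generation pipeline, which is what the commit message refers to. An end-to-end sketch, again assuming a hypothetical repository id Sharjeelbaig/NeuroThinker and that a compatible tokenizer is stored in the same repo (neither is confirmed by this diff):

# Sketch only; repo_id and the presence of a tokenizer in the repo are assumptions.
from transformers import AutoTokenizer, AutoModelForCausalLM, pipeline

repo_id = "Sharjeelbaig/NeuroThinker"

tokenizer = AutoTokenizer.from_pretrained(repo_id)
model = AutoModelForCausalLM.from_pretrained(repo_id, trust_remote_code=True)

# The generation loop relies on prepare_inputs_for_generation above and on the
# pad/bos/eos token ids written to config.json.
generator = pipeline("text-generation", model=model, tokenizer=tokenizer)
out = generator("The memory module updates when", max_new_tokens=32, do_sample=True)
print(out[0]["generated_text"])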