JonusNattapong committed on
Commit b35f3f5 · verified · 1 Parent(s): 2f20c3c

Upload trained model

Files changed (1)
  1. modeling.py +496 -0
modeling.py ADDED
@@ -0,0 +1,496 @@
import math
from types import SimpleNamespace
import json
import os
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.checkpoint import checkpoint as grad_checkpoint


def rotate_every_two(x):
    # pair adjacent channels (x0, x1) -> (-x1, x0): the interleaved rotary convention
    x1 = x[..., ::2]
    x2 = x[..., 1::2]
    return torch.stack((-x2, x1), dim=-1).reshape_as(x)


def apply_rotary_pos_emb(q, k, sin, cos):
    # q, k: (B, nh, T, hs)
    q_ = (q * cos) + (rotate_every_two(q) * sin)
    k_ = (k * cos) + (rotate_every_two(k) * sin)
    return q_, k_


class RotaryEmbedding(nn.Module):
    def __init__(self, dim):
        super().__init__()
        inv_freq = 1.0 / (10000 ** (torch.arange(0, dim, 2).float() / dim))
        self.register_buffer('inv_freq', inv_freq)

    def forward(self, seq_len, device):
        t = torch.arange(seq_len, device=device).type_as(self.inv_freq)
        freqs = torch.einsum('i , j -> i j', t, self.inv_freq)
        # interleave the frequencies so sin/cos line up with rotate_every_two's
        # even/odd channel pairing; torch.cat((freqs, freqs)) would match the
        # rotate-half convention instead and mis-rotate the pairs here
        emb = torch.repeat_interleave(freqs, 2, dim=-1)  # (T, dim)
        sin = emb.sin()[None, None, :, :]
        cos = emb.cos()[None, None, :, :]
        return sin, cos
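

# A minimal sanity-check sketch (illustrative, not part of the model code):
# rotary embeddings rotate each (even, odd) channel pair of q and k, so vector
# norms are preserved and attention scores become relative-position dependent.
def _rotary_demo():
    rot = RotaryEmbedding(dim=64)
    sin, cos = rot(seq_len=16, device='cpu')  # each (1, 1, 16, 64)
    q = torch.randn(2, 8, 16, 64)             # (B, nh, T, hs)
    k = torch.randn(2, 8, 16, 64)
    q_, k_ = apply_rotary_pos_emb(q, k, sin, cos)
    # a pure rotation leaves per-head vector norms unchanged
    assert torch.allclose(q.norm(dim=-1), q_.norm(dim=-1), atol=1e-4)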


class RMSNorm(nn.Module):
    """Simple RMSNorm implementation compatible with HF's RMSNorm behavior."""
    def __init__(self, dim, eps=1e-8):
        super().__init__()
        self.eps = eps
        self.scale = nn.Parameter(torch.ones(dim))

    def forward(self, x):
        # x: (B, T, C); rescale by the root-mean-square over the channel dim
        norm = x.pow(2).mean(-1, keepdim=True).add(self.eps).rsqrt()
        return x * norm * self.scale
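

# A quick numeric sketch (illustrative): with the default scale of ones, the
# output of RMSNorm has unit root-mean-square along the channel dimension.
def _rmsnorm_demo():
    x = torch.randn(2, 4, 512)
    y = RMSNorm(512)(x)
    rms = y.pow(2).mean(-1).sqrt()  # ~1.0 everywhere
    assert torch.allclose(rms, torch.ones_like(rms), atol=1e-3)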


class MultiHeadAttention(nn.Module):
    def __init__(self, n_embd, n_head, attn_pdrop=0.1, resid_pdrop=0.1, use_rotary=True):
        super().__init__()
        assert n_embd % n_head == 0
        self.n_head = n_head
        self.head_dim = n_embd // n_head
        self.scale = 1.0 / math.sqrt(self.head_dim)

        self.qkv = nn.Linear(n_embd, n_embd * 3, bias=False)
        self.proj = nn.Linear(n_embd, n_embd)
        self.attn_dropout = nn.Dropout(attn_pdrop)
        self.resid_dropout = nn.Dropout(resid_pdrop)
        self.use_rotary = use_rotary
        if use_rotary:
            self.rotary = RotaryEmbedding(self.head_dim)

        # optional flash-attention detection (detection only: the forward pass
        # always takes the manual matmul path until a real integration lands,
        # since flash-attn APIs vary by package version)
        self.use_flash = False
        try:
            import flash_attn  # type: ignore  # noqa: F401
            self.use_flash = True
        except Exception:
            self.use_flash = False

    def forward(self, x, attn_mask=None):
        B, T, C = x.size()
        qkv = self.qkv(x).view(B, T, 3, self.n_head, self.head_dim).permute(2, 0, 3, 1, 4)
        q, k, v = qkv[0], qkv[1], qkv[2]  # each (B, nh, T, hs)

        if self.use_rotary:
            sin, cos = self.rotary(T, device=x.device)
            q, k = apply_rotary_pos_emb(q, k, sin, cos)

        # manual scaled dot-product attention
        att = torch.matmul(q, k.transpose(-2, -1)) * self.scale

        # causal mask: position t may only attend to positions <= t
        causal_mask = torch.tril(torch.ones(T, T, device=x.device)).view(1, 1, T, T)
        att = att.masked_fill(causal_mask == 0, float('-inf'))

        # optional padding mask (HF-style attention_mask of shape (B, T))
        if attn_mask is not None:
            if attn_mask.dim() == 2:
                attn_mask = attn_mask.view(B, 1, 1, T)
            att = att.masked_fill(attn_mask == 0, float('-inf'))

        att = F.softmax(att, dim=-1)
        att = self.attn_dropout(att)
        y = torch.matmul(att, v)  # (B, nh, T, hs)
        y = y.transpose(1, 2).contiguous().view(B, T, C)
        y = self.proj(y)
        y = self.resid_dropout(y)
        return y
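

# A minimal shape sketch (illustrative): the causal mask keeps position t from
# attending to future positions, and the output preserves the input shape.
def _attention_demo():
    attn = MultiHeadAttention(n_embd=512, n_head=8)
    x = torch.randn(2, 16, 512)                 # (B, T, C)
    mask = torch.ones(2, 16, dtype=torch.long)  # HF-style attention_mask
    y = attn(x, attn_mask=mask)
    assert y.shape == (2, 16, 512)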


class SwiGLU(nn.Module):
    def __init__(self, dim_in, dim_out):
        super().__init__()
        # dim_out is the inner dim; we keep the ability to set it equal to dim_in for smaller models
        self.fc1 = nn.Linear(dim_in, dim_out)
        self.fc_gate = nn.Linear(dim_in, dim_out)
        self.fc2 = nn.Linear(dim_out, dim_in)
        self.dropout = nn.Dropout(0.0)

    def forward(self, x):
        # gated unit: a SiLU-squashed branch multiplied elementwise into a linear gate;
        # dropout is a no-op at p=0.0 but is applied so a nonzero rate would take effect
        return self.dropout(self.fc2(F.silu(self.fc1(x)) * self.fc_gate(x)))
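

# A minimal equivalence sketch (illustrative) spelling out the gated form used
# by SwiGLU above: fc2(silu(fc1(x)) * fc_gate(x)).
def _swiglu_demo():
    m = SwiGLU(512, 512)
    x = torch.randn(2, 16, 512)
    manual = m.fc2(F.silu(m.fc1(x)) * m.fc_gate(x))
    assert torch.allclose(m(x), manual)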


class FeedForward(nn.Module):
    def __init__(self, n_embd, mlp_ratio=1.0, pdrop=0.1, inner_dim=None):
        super().__init__()
        # allow an inner_dim override; by default the inner width is n_embd * mlp_ratio,
        # which reduces to the embedding width for this compact model
        if inner_dim is None:
            inner = int(n_embd * mlp_ratio)
        else:
            inner = inner_dim
        self.fn = SwiGLU(n_embd, inner)
        self.dropout = nn.Dropout(pdrop)

    def forward(self, x, tag_emb=None):
        # tag_emb is accepted for API compatibility with MoE variants that may use a router bias
        return self.dropout(self.fn(x))


class MoEFeedForward(nn.Module):
    """Mixture-of-Experts feedforward: small top-k router routing per token.

    Notes: simplified router for resource-constrained mini models. Uses token-level routing.
    """
    def __init__(self, n_embd, num_experts=4, top_k=1, expert_ctor=None, router_temperature=1.0, aux_coef=0.0, tag_proj_dim=None):
        super().__init__()
        self.num_experts = num_experts
        self.top_k = top_k
        self.router_temperature = router_temperature
        self.aux_coef = aux_coef
        assert 1 <= top_k <= num_experts
        if expert_ctor is None:
            expert_ctor = lambda: FeedForward(n_embd)
        self.experts = nn.ModuleList([expert_ctor() for _ in range(num_experts)])
        # lightweight router: linear to num_experts
        self.router = nn.Linear(n_embd, num_experts)
        # optional projection from a tag embedding (B, C) -> (B, num_experts) to bias router logits
        self.tag_proj = nn.Linear(tag_proj_dim, num_experts) if tag_proj_dim is not None else None

    def forward(self, x, tag_emb=None):
        # x: (B, T, C)
        B, T, C = x.size()
        logits = self.router(x)  # (B, T, num_experts)
        # if a tag embedding (B, C) is provided and we have a projection, add it as a bias
        if tag_emb is not None and self.tag_proj is not None:
            # project the per-batch tag embedding to expert logits and broadcast to tokens:
            # tag_emb: (B, C) -> (B, num_experts) -> (B, 1, num_experts)
            tag_bias = self.tag_proj(tag_emb).unsqueeze(1)
            logits = logits + tag_bias
        # apply temperature to the router logits
        if self.router_temperature and self.router_temperature != 1.0:
            probs = F.softmax(logits / float(self.router_temperature), dim=-1)
        else:
            probs = F.softmax(logits, dim=-1)
        # top-k routing weights and expert indices per token
        topk = probs.topk(self.top_k, dim=-1)
        indices = topk.indices  # (B, T, top_k)
        weights = topk.values   # (B, T, top_k)

        out = x.new_zeros(B, T, C)
        # naive per-expert dispatch (slower than a fused scatter, but simple)
        for e in range(self.num_experts):
            # mask tokens that route to expert e in any of their top_k slots
            mask = (indices == e)  # (B, T, top_k)
            if not mask.any():
                continue
            sel = mask.any(-1)  # (B, T) tokens touched by expert e
            inp = x[sel]
            expert_out = self.experts[e](inp)
            # routing weight for the selected tokens: sum the top_k weights whose index == e
            w = torch.zeros(B, T, device=x.device)
            for k in range(self.top_k):
                w = w + (indices[..., k] == e).float() * weights[..., k]
            w_sel = w[sel].unsqueeze(-1)
            out[sel] = out[sel] + expert_out * w_sel

        # compute a lightweight auxiliary load-balancing loss (optional)
        self.last_aux_loss = None
        if getattr(self, 'aux_coef', 0.0):
            # average probability mass per expert across tokens; penalizing the
            # squared load pushes the router toward a uniform expert distribution
            load = probs.sum(dim=(0, 1)) / (B * T)
            aux = (load * load).sum()
            self.last_aux_loss = aux * float(self.aux_coef)

        return out
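

# A minimal routing sketch (illustrative): with top_k=1 each token's output is
# a single expert's output scaled by its router probability, so the module is
# drop-in shape-compatible with FeedForward.
def _moe_demo():
    moe = MoEFeedForward(n_embd=512, num_experts=4, top_k=1)
    x = torch.randn(2, 16, 512)
    y = moe(x)
    assert y.shape == x.shape
    # moe.last_aux_loss stays None unless aux_coef > 0 is passed at construction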


class TransformerBlock(nn.Module):
    def __init__(self, n_embd, n_head, mlp_ratio=4, attn_pdrop=0.1, resid_pdrop=0.1, use_rotary=True):
        super().__init__()
        self.ln1 = nn.LayerNorm(n_embd)
        self.attn = MultiHeadAttention(n_embd, n_head, attn_pdrop, resid_pdrop, use_rotary=use_rotary)
        self.ln2 = nn.LayerNorm(n_embd)
        self.mlp = FeedForward(n_embd, mlp_ratio, resid_pdrop)

    def forward(self, x, attn_mask=None, tag_emb=None):
        # pre-norm residual layout: x + Attn(LN(x)), then x + MLP(LN(x))
        x = x + self.attn(self.ln1(x), attn_mask=attn_mask)
        # both FeedForward and the MoE variant accept tag_emb, so it is passed unconditionally
        x = x + self.mlp(self.ln2(x), tag_emb=tag_emb)
        return x


class Hanuman(nn.Module):
    """Hanuman: advanced GPT-like mini model with rotary embeddings and SwiGLU MLP.

    Forward signature is compatible with HF GPT2LMHeadModel: forward(input_ids, attention_mask, labels).
    Returns SimpleNamespace(loss=..., logits=...).
    """

    def __init__(self, *, vocab_size, n_positions=4096, n_embd=512, n_layer=8, n_head=8, mlp_ratio=1.0,
                 attn_pdrop=0.1, resid_pdrop=0.1, use_rotary=True, use_rmsnorm=True, use_moe=False,
                 moe_experts=4, moe_top_k=1, gradient_checkpointing=False, use_think_head=False, think_aux_coef=1.0):
        super().__init__()
        self.vocab_size = vocab_size
        self.n_positions = n_positions
        self.n_embd = n_embd

        self.use_rmsnorm = use_rmsnorm
        self.gradient_checkpointing = gradient_checkpointing

        self.wte = nn.Embedding(vocab_size, n_embd)
        self.wpe = nn.Embedding(n_positions, n_embd)
        self.drop = nn.Dropout(0.1)

        self.blocks = nn.ModuleList()
        for _ in range(n_layer):
            blk = TransformerBlock(n_embd, n_head, mlp_ratio, attn_pdrop, resid_pdrop, use_rotary=use_rotary)
            self.blocks.append(blk)

        # final norm: RMSNorm or LayerNorm
        if use_rmsnorm:
            self.ln_f = RMSNorm(n_embd)
        else:
            self.ln_f = nn.LayerNorm(n_embd)

        # optional MoE: swap each block's feedforward for the MoE variant
        if use_moe:
            for blk in self.blocks:
                blk.mlp = MoEFeedForward(n_embd, num_experts=moe_experts, top_k=moe_top_k,
                                         expert_ctor=lambda: FeedForward(n_embd, mlp_ratio=mlp_ratio, inner_dim=n_embd))

        self.head = nn.Linear(n_embd, vocab_size, bias=False)

        # optional think head for intermediate reasoning outputs (same vocab by default)
        self.use_think_head = use_think_head
        self.think_aux_coef = float(think_aux_coef)
        if use_think_head:
            self.think_head = nn.Linear(n_embd, vocab_size, bias=False)

    def forward(self, input_ids=None, attention_mask=None, labels=None, thought_labels=None):
        B, T = input_ids.size()
        assert T <= self.n_positions, f"Sequence length {T} > model max {self.n_positions}"

        pos = torch.arange(0, T, dtype=torch.long, device=input_ids.device).unsqueeze(0)
        x = self.wte(input_ids) + self.wpe(pos)
        x = self.drop(x)

        # If the caller provided a special effort-tag token (expected at position 0), compute tag_emb
        tag_emb = None
        try:
            # think_token_ids is an optional dict attribute set on the model externally
            if hasattr(self, 'think_token_ids') and isinstance(self.think_token_ids, dict):
                first = input_ids[:, 0]
                # if a known tag id is present, build tag_emb from its token embedding
                for tag, tid in self.think_token_ids.items():
                    if (first == tid).any():
                        # wrap the raw int id in a tensor before embedding lookup
                        tid_t = torch.tensor(tid, dtype=torch.long, device=input_ids.device)
                        tag_emb = self.wte(tid_t).unsqueeze(0).expand(input_ids.size(0), -1)
                        break
        except Exception:
            tag_emb = None

        for blk in self.blocks:
            if self.gradient_checkpointing and self.training:
                x = grad_checkpoint(blk, x, attention_mask, tag_emb, use_reentrant=False)
            else:
                x = blk(x, attn_mask=attention_mask, tag_emb=tag_emb)

        x = self.ln_f(x)
        logits = self.head(x)

        loss = None
        thought_loss = None
        if labels is not None:
            loss_fct = nn.CrossEntropyLoss(ignore_index=-100)
            lm_loss = loss_fct(logits.view(-1, logits.size(-1)), labels.view(-1))
            loss = lm_loss

        # optional thinking-head loss
        thought_logits = None
        if self.use_think_head and thought_labels is not None:
            thought_logits = self.think_head(x)
            loss_fct = nn.CrossEntropyLoss(ignore_index=-100)
            thought_loss = loss_fct(thought_logits.view(-1, thought_logits.size(-1)), thought_labels.view(-1))
            if loss is None:
                loss = thought_loss * self.think_aux_coef
            else:
                loss = loss + thought_loss * self.think_aux_coef

        return SimpleNamespace(loss=loss, logits=logits, thought_logits=thought_logits, thought_loss=thought_loss)
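
    # A minimal usage sketch (illustrative). Unlike HF's GPT2LMHeadModel, the
    # loss above does not shift logits/labels internally, so callers should
    # pre-shift labels themselves (positions to ignore are marked with -100):
    #
    #   model = Hanuman(vocab_size=32000)
    #   ids = torch.randint(0, 32000, (2, 16))
    #   out = model(input_ids=ids, labels=ids)
    #   out.loss.backward()      # scalar cross-entropy
    #   out.logits.shape         # (2, 16, 32000)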

    # runtime helpers
    def to_device(self, device):
        self.to(device)

    def enable_fp16(self):
        # cast model params to float16 where safe
        self.half()

    def set_gradient_checkpointing(self, enabled: bool):
        self.gradient_checkpointing = enabled

    # Simple autoregressive generator (CPU/GPU). Not optimized for speed: the
    # full prefix is re-encoded every step (no KV cache), and the eos check
    # below assumes batch size 1.
    @torch.no_grad()
    def generate(self, input_ids, max_new_tokens=50, temperature=1.0, top_k=0, top_p=0.0, eos_token_id=None):
        self.eval()
        out = input_ids
        for _ in range(max_new_tokens):
            logits = self.forward(input_ids=out).logits
            next_logits = logits[:, -1, :] / (temperature if temperature > 0 else 1.0)
            if top_k > 0:
                vals, idx = torch.topk(next_logits, top_k)
                probs = torch.zeros_like(next_logits).scatter(1, idx, F.softmax(vals, dim=-1))
            elif top_p > 0.0:
                # nucleus sampling: keep the smallest prefix of sorted tokens
                # whose cumulative probability first exceeds top_p
                sorted_logits, sorted_indices = torch.sort(next_logits, descending=True)
                cumulative_probs = torch.cumsum(F.softmax(sorted_logits, dim=-1), dim=-1)
                cutoff = cumulative_probs > top_p
                cutoff_index = torch.argmax(cutoff.int(), dim=-1)
                mask = torch.zeros_like(sorted_logits).bool()
                for b in range(sorted_logits.size(0)):
                    mask[b, :cutoff_index[b] + 1] = True
                probs = torch.zeros_like(next_logits)
                probs.scatter_(1, sorted_indices, F.softmax(sorted_logits, dim=-1) * mask.float())
            else:
                probs = F.softmax(next_logits, dim=-1)

            next_token = torch.multinomial(probs, num_samples=1)
            out = torch.cat([out, next_token], dim=1)
            if eos_token_id is not None and next_token.item() == eos_token_id:
                break
        return out
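
    # A minimal usage sketch (illustrative); note that top_k takes precedence
    # over top_p in the branches above:
    #
    #   out = model.generate(ids, max_new_tokens=20, temperature=0.8, top_k=50)
    #   new_tokens = out[:, ids.size(1):]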

    @torch.no_grad()
    def generate_effort(self, input_ids, effort='short', reason_budget=None, temperature=1.0, top_k=0, top_p=0.0, eos_token_id=None):
        """
        Two-phase decoding: spend up to reason_budget tokens on reasoning (the
        caller is expected to open a <scratch> block in the prompt and insert
        <final> before the answer phase if those markers are used), then sample
        the final answer.
        effort in {'none', 'short', 'medium', 'long'} maps to default budgets when reason_budget is None.
        This is a simple, synchronous implementation; production should use batched, streaming decodes.
        Note: top_k and top_p are accepted for signature parity with generate() but are unused here.
        """
        budget_map = {'none': 0, 'short': 64, 'medium': 256, 'long': 1024}
        if reason_budget is None:
            reason_budget = budget_map.get(effort, 64)

        # phase 1: generate reasoning tokens if budget > 0
        out = input_ids
        if reason_budget > 0:
            for _ in range(reason_budget):
                logits = self.forward(input_ids=out).logits
                next_logits = logits[:, -1, :] / (temperature if temperature > 0 else 1.0)
                probs = F.softmax(next_logits, dim=-1)
                next_token = torch.multinomial(probs, num_samples=1)
                out = torch.cat([out, next_token], dim=1)
        # phase 2: generate the final answer until eos or a short fixed length
        final_out = out
        for _ in range(128):
            logits = self.forward(input_ids=final_out).logits
            next_logits = logits[:, -1, :] / (temperature if temperature > 0 else 1.0)
            probs = F.softmax(next_logits, dim=-1)
            next_token = torch.multinomial(probs, num_samples=1)
            final_out = torch.cat([final_out, next_token], dim=1)
            if eos_token_id is not None and next_token.item() == eos_token_id:
                break
        return final_out
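
    # A minimal usage sketch (illustrative), assuming any <scratch>/<final>
    # markers are already encoded into the prompt by the caller's tokenizer:
    #
    #   out = model.generate_effort(ids, effort='medium')   # ~256 reasoning tokens
    #   out = model.generate_effort(ids, reason_budget=32)  # explicit budget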
407
+ # Utilities to play nice with train.py expectations
408
+ def save_pretrained(self, out_dir: str, use_safetensors: bool = False):
409
+ os.makedirs(out_dir, exist_ok=True)
410
+ # save state and a small config
411
+ model_path = os.path.join(out_dir, 'pytorch_model.bin')
412
+ cfg = {
413
+ 'vocab_size': self.vocab_size,
414
+ 'n_positions': self.n_positions,
415
+ 'n_embd': self.n_embd,
416
+ 'n_layer': len(self.blocks),
417
+ 'n_head': self.blocks[0].attn.n_head if len(self.blocks) else 0,
418
+ }
419
+ with open(os.path.join(out_dir, 'config.json'), 'w', encoding='utf-8') as f:
420
+ json.dump(cfg, f)
421
+
422
+ if use_safetensors:
423
+ try:
424
+ from safetensors.torch import save_file as safe_save
425
+ state = {k: v.cpu() for k, v in self.state_dict().items()}
426
+ safe_save(state, os.path.join(out_dir, 'pytorch_model.safetensors'))
427
+ return
428
+ except Exception:
429
+ # fallback to torch.save if safetensors isn't available
430
+ pass
431
+
432
+ torch.save(self.state_dict(), model_path)
433
+
434
+ @classmethod
435
+ def from_pretrained(cls, in_dir: str, map_location=None):
436
+ with open(os.path.join(in_dir, 'config.json'), 'r', encoding='utf-8') as f:
437
+ cfg = json.load(f)
438
+ model = cls(
439
+ vocab_size=cfg.get('vocab_size', 32000),
440
+ n_positions=cfg.get('n_positions', 1024),
441
+ n_embd=cfg.get('n_embd', 768),
442
+ n_layer=cfg.get('n_layer', 12),
443
+ n_head=cfg.get('n_head', 12),
444
+ )
445
+ # Prefer safetensors if present
446
+ safetensors_path = os.path.join(in_dir, 'pytorch_model.safetensors')
447
+ bin_path = os.path.join(in_dir, 'pytorch_model.bin')
448
+ if os.path.exists(safetensors_path):
449
+ try:
450
+ from safetensors.torch import load_file as safe_load
451
+ state = safe_load(safetensors_path, device=map_location or 'cpu')
452
+ except Exception:
453
+ state = torch.load(safetensors_path, map_location=map_location)
454
+ elif os.path.exists(bin_path):
455
+ state = torch.load(bin_path, map_location=map_location)
456
+ else:
457
+ raise FileNotFoundError(f'No model file found in {in_dir}')
458
+
459
+ # state is a mapping of tensors
460
+ model.load_state_dict(state)
461
+ return model
462
+
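
    # A minimal round-trip sketch (illustrative). Note that MoE and think-head
    # settings are not written to config.json, so from_pretrained reconstructs
    # only the base architecture:
    #
    #   model.save_pretrained('ckpt', use_safetensors=True)
    #   reloaded = Hanuman.from_pretrained('ckpt')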

    def resize_token_embeddings(self, new_vocab_size: int):
        old_wte = self.wte
        old_vocab, emb_dim = old_wte.weight.shape
        if new_vocab_size == old_vocab:
            return
        n_copy = min(old_vocab, new_vocab_size)
        new_wte = nn.Embedding(new_vocab_size, emb_dim)
        # copy existing rows (truncating if the vocab shrinks)
        with torch.no_grad():
            new_wte.weight[:n_copy] = old_wte.weight[:n_copy]
        self.wte = new_wte

        # the LM head weight has shape (vocab_size, emb_dim), so copy along dim 0
        new_head = nn.Linear(emb_dim, new_vocab_size, bias=False)
        with torch.no_grad():
            new_head.weight[:n_copy, :] = self.head.weight[:n_copy, :]
        self.head = new_head
        self.vocab_size = new_vocab_size


def build_from_config(config):
    # Build Hanuman from a GPT2Config-like object with mini-model defaults
    return Hanuman(
        vocab_size=getattr(config, 'vocab_size', 32000),
        n_positions=getattr(config, 'n_positions', getattr(config, 'n_ctx', 4096)),
        n_embd=getattr(config, 'n_embd', 512),
        n_layer=getattr(config, 'n_layer', 8),
        n_head=getattr(config, 'n_head', 8),
        mlp_ratio=getattr(config, 'mlp_ratio', 1.0),
        use_rmsnorm=getattr(config, 'use_rmsnorm', True),
        use_moe=getattr(config, 'use_moe', False),
        moe_experts=getattr(config, 'moe_experts', 4),
        moe_top_k=getattr(config, 'moe_top_k', 1),
        gradient_checkpointing=getattr(config, 'gradient_checkpointing', False),
        use_think_head=getattr(config, 'use_think_head', False),
        think_aux_coef=getattr(config, 'think_aux_coef', 1.0),
    )
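

# A minimal end-to-end sketch (illustrative; the hyperparameters are arbitrary
# small values and SimpleNamespace stands in for a GPT2Config-like object):
if __name__ == '__main__':
    config = SimpleNamespace(vocab_size=1000, n_positions=128, n_embd=64,
                             n_layer=2, n_head=4, use_moe=True)
    model = build_from_config(config)
    ids = torch.randint(0, 1000, (1, 16))
    out = model(input_ids=ids, labels=ids)
    print(out.logits.shape, out.loss.item())  # torch.Size([1, 16, 1000]) and a scalar loss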