Anurich commited on
Commit
ca73a01
·
verified ·
1 Parent(s): 42565a9

Upload Jeeves model (trust_remote_code)

Browse files
README.md ADDED
@@ -0,0 +1,51 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ ---
2
+ library_name: transformers
3
+ tags:
4
+ - jeeves
5
+ - causal-lm
6
+ - looped-transformer
7
+ - value-residual
8
+ - sentencepiece
9
+ license: apache-2.0
10
+ ---
11
+
12
+ # Jeeves (75M)
13
+
14
+ A compact language model using **Looped Transformer + Value Residual Learning**.
15
+
16
+ ## Usage
17
+
18
+ ```python
19
+ from transformers import AutoTokenizer, AutoModelForCausalLM
20
+
21
+ tokenizer = AutoTokenizer.from_pretrained("REPO_ID", trust_remote_code=True)
22
+ model = AutoModelForCausalLM.from_pretrained("REPO_ID", trust_remote_code=True)
23
+
24
+ inputs = tokenizer("Hello, how are you?", return_tensors="pt")
25
+ outputs = model.generate(**inputs, max_new_tokens=50)
26
+ print(tokenizer.decode(outputs[0], skip_special_tokens=True))
27
+ ```
28
+
29
+ **Note:** `trust_remote_code=True` is required.
30
+
31
+ ## Architecture
32
+
33
+ | Component | Value |
34
+ |---|---|
35
+ | Parameters | 74.9M |
36
+ | Unique layers | 8 |
37
+ | Effective depth | 15 |
38
+ | Loop | block[4] x 8 |
39
+ | Value residual | True |
40
+ | Hidden dim | 768 |
41
+ | FFN dim | 2048 |
42
+ | Attention heads | 12 (Q) / 4 (KV) |
43
+ | Vocab size | 32,000 |
44
+ | Max seq length | 512 |
45
+ | Training step | 1,100 |
46
+
47
+ ## Key Innovations
48
+
49
+ - **Looped Transformer** ([arXiv 2311.12424](https://arxiv.org/abs/2311.12424))
50
+ - **Value Residual Learning** ([arXiv 2410.17897](https://arxiv.org/abs/2410.17897))
51
+ - **Input Injection** for loop stability
config.json ADDED
@@ -0,0 +1,40 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "architectures": [
3
+ "JeevesForCausalLM"
4
+ ],
5
+ "auto_map": {
6
+ "AutoConfig": "configuration_jeeves.JeevesConfig",
7
+ "AutoModelForCausalLM": "modeling_jeeves.JeevesForCausalLM",
8
+ "AutoTokenizer": [
9
+ "tokenization_jeeves.JeevesTokenizer",
10
+ null
11
+ ]
12
+ },
13
+ "bos_token_id": 1,
14
+ "d_ff": 2048,
15
+ "d_model": 768,
16
+ "dropout": 0.0,
17
+ "dtype": "float32",
18
+ "eos_token_id": 2,
19
+ "head_dim": 64,
20
+ "hidden_size": 768,
21
+ "init_std": 0.02,
22
+ "loop_block_idx": 4,
23
+ "max_seq_len": 512,
24
+ "model_type": "jeeves",
25
+ "n_heads": 12,
26
+ "n_kv_heads": 4,
27
+ "n_layers": 8,
28
+ "n_loop_iters": 8,
29
+ "norm_eps": 1e-05,
30
+ "pad_token_id": 0,
31
+ "rope_base": 10000.0,
32
+ "tie_embeddings": true,
33
+ "tie_word_embeddings": true,
34
+ "transformers_version": "5.0.0",
35
+ "use_flash_attention": false,
36
+ "use_input_injection": true,
37
+ "use_value_residual": true,
38
+ "value_residual_alpha_init": -2.0,
39
+ "vocab_size": 32000
40
+ }
configuration_jeeves.py ADDED
@@ -0,0 +1,96 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """HuggingFace-compatible configuration for Jeeves.
2
+
3
+ This file gets uploaded to the Hub so users can load with:
4
+ from transformers import AutoConfig
5
+ config = AutoConfig.from_pretrained("Anurich/Jeeves-Small-75M", trust_remote_code=True)
6
+ """
7
+
8
+ from transformers import PretrainedConfig
9
+
10
+
11
class JeevesConfig(PretrainedConfig):
    """Configuration for the Jeeves language model.

    Jeeves is a Looped Transformer with Value Residual Learning: the single
    block at ``loop_block_idx`` is applied ``n_loop_iters`` times (optionally
    with input injection), so the effective depth exceeds the count of unique
    parameterized layers.
    """

    model_type = "jeeves"

    def __init__(
        self,
        d_model: int = 768,
        n_layers: int = 8,
        n_heads: int = 12,
        n_kv_heads: int = 4,
        vocab_size: int = 32000,
        max_seq_len: int = 512,
        d_ff: int = None,
        norm_eps: float = 1e-5,
        rope_base: float = 10000.0,
        tie_embeddings: bool = True,
        dropout: float = 0.0,
        init_std: float = 0.02,
        use_flash_attention: bool = True,
        # Looped Transformer
        loop_block_idx: int = None,
        n_loop_iters: int = 1,
        use_input_injection: bool = True,
        # Value Residual Learning
        use_value_residual: bool = False,
        value_residual_alpha_init: float = -2.0,
        # Special tokens
        pad_token_id: int = 0,
        bos_token_id: int = 1,
        eos_token_id: int = 2,
        **kwargs,
    ):
        # config.json round-trips tie_word_embeddings; drop any incoming copy
        # so the explicit keyword in super().__init__ is the single source of
        # truth (otherwise the call would raise on a duplicate kwarg).
        kwargs.pop("tie_word_embeddings", None)
        super().__init__(
            pad_token_id=pad_token_id,
            bos_token_id=bos_token_id,
            eos_token_id=eos_token_id,
            tie_word_embeddings=tie_embeddings,
            **kwargs,
        )

        # Core transformer geometry.
        self.d_model = d_model
        self.n_layers = n_layers
        self.n_heads = n_heads
        self.n_kv_heads = n_kv_heads
        self.vocab_size = vocab_size
        self.max_seq_len = max_seq_len
        self.norm_eps = norm_eps
        self.rope_base = rope_base
        self.tie_embeddings = tie_embeddings
        self.dropout = dropout
        self.init_std = init_std
        self.use_flash_attention = use_flash_attention

        # Looped Transformer
        self.loop_block_idx = loop_block_idx
        self.n_loop_iters = n_loop_iters
        self.use_input_injection = use_input_injection

        # Value Residual Learning
        self.use_value_residual = use_value_residual
        self.value_residual_alpha_init = value_residual_alpha_init

        # FFN width: explicit value wins; otherwise SwiGLU sizing of roughly
        # 8/3 * d_model, rounded up to the next multiple of 256.
        if d_ff is not None:
            self.d_ff = d_ff
        else:
            raw = int(8 / 3 * d_model)
            self.d_ff = -(-raw // 256) * 256

        # Derived quantities (HF code reads hidden_size by convention).
        self.head_dim = d_model // n_heads
        self.hidden_size = d_model

    @property
    def effective_depth(self) -> int:
        """Number of block applications in one forward pass.

        early (loop_block_idx) + looped (n_loop_iters)
        + late (n_layers - loop_block_idx - 1) == n_layers - 1 + n_loop_iters.
        """
        if self.loop_block_idx is None:
            return self.n_layers
        return self.n_layers - 1 + self.n_loop_iters
generation_config.json ADDED
@@ -0,0 +1,9 @@
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "do_sample": true,
3
+ "temperature": 0.7,
4
+ "top_k": 50,
5
+ "top_p": 0.9,
6
+ "max_new_tokens": 512,
7
+ "eos_token_id": 2,
8
+ "pad_token_id": 0
9
+ }
model.safetensors ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:3b009ae3742bb1ec3ee1357f6cfd5e66f75e9162b3616709d0c2beb3b758c7fa
3
+ size 299691520
modeling_jeeves.py ADDED
@@ -0,0 +1,314 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """HuggingFace-compatible Jeeves model.
2
+
3
+ This file gets uploaded to the Hub so users can load with:
4
+ from transformers import AutoModelForCausalLM
5
+ model = AutoModelForCausalLM.from_pretrained("Anurich/Jeeves-Small-75M", trust_remote_code=True)
6
+
7
+ The architecture is self-contained — no local imports needed.
8
+ Features: Looped Transformer + Value Residual Learning + GQA + RoPE + SwiGLU.
9
+ """
10
+
11
+ import math
12
+ from typing import Optional, Tuple
13
+
14
+ import torch
15
+ import torch.nn as nn
16
+ import torch.nn.functional as F
17
+ from transformers import GenerationMixin, PreTrainedModel
18
+ from transformers.modeling_outputs import CausalLMOutputWithPast
19
+
20
+ from .configuration_jeeves import JeevesConfig
21
+
22
+
23
+ # ---------------------------------------------------------------------------
24
+ # Core layers
25
+ # ---------------------------------------------------------------------------
26
+
27
class RMSNorm(nn.Module):
    """RMS layer normalization: rescales by the root-mean-square, no mean
    centering. The statistic is computed in fp32 for numerical stability and
    the result is cast back to the input dtype before the learned scale."""

    def __init__(self, dim: int, eps: float = 1e-5):
        super().__init__()
        self.eps = eps
        self.weight = nn.Parameter(torch.ones(dim))

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        xf = x.float()
        inv_rms = torch.rsqrt(xf.pow(2).mean(dim=-1, keepdim=True) + self.eps)
        normalized = (xf * inv_rms).type_as(x)
        return normalized * self.weight
37
+
38
+
39
class SwiGLUFFN(nn.Module):
    """SwiGLU feed-forward network: down_proj(silu(gate_proj(x)) * up_proj(x)).

    All projections are bias-free; dropout is applied to the final output and
    replaced by an identity when the rate is zero.
    """

    def __init__(self, d_model: int, d_ff: int, dropout: float = 0.0):
        super().__init__()
        self.gate_proj = nn.Linear(d_model, d_ff, bias=False)
        self.up_proj = nn.Linear(d_model, d_ff, bias=False)
        self.down_proj = nn.Linear(d_ff, d_model, bias=False)
        self.dropout = nn.Identity() if dropout <= 0 else nn.Dropout(dropout)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        gated = F.silu(self.gate_proj(x)) * self.up_proj(x)
        return self.dropout(self.down_proj(gated))
50
+
51
+
52
+ # ---------------------------------------------------------------------------
53
+ # RoPE
54
+ # ---------------------------------------------------------------------------
55
+
56
def precompute_rope_freqs(head_dim: int, max_seq_len: int, base: float = 10000.0,
                          device=None) -> torch.Tensor:
    """Build the complex RoPE rotation table.

    Returns a (max_seq_len, head_dim // 2) complex tensor whose entry at
    (position, pair) is exp(i * position * inv_freq[pair]), with the standard
    geometric inverse-frequency schedule over even dimensions.
    """
    exponents = torch.arange(0, head_dim, 2, device=device).float() / head_dim
    inv_freq = 1.0 / (base ** exponents)
    positions = torch.arange(max_seq_len, device=device).float()
    angles = torch.outer(positions, inv_freq)
    return torch.polar(torch.ones_like(angles), angles)
62
+
63
+
64
def apply_rope(q, k, freqs_cis):
    """Apply rotary position embeddings to q and k via complex multiplication.

    Inputs are (batch, seq, heads, head_dim); adjacent feature pairs are viewed
    as complex numbers and rotated by ``freqs_cis``. MPS lacks complex support,
    so that device falls back to the real-arithmetic implementation.
    """
    if q.device.type == 'mps':
        return _apply_rope_real(q, k, freqs_cis)

    # Broadcast the (seq, head_dim/2) table over batch and head axes.
    rotation = freqs_cis.unsqueeze(0).unsqueeze(2)

    def _rotate(t):
        as_complex = torch.view_as_complex(t.float().reshape(*t.shape[:-1], -1, 2))
        return torch.view_as_real(as_complex * rotation).flatten(-2).type_as(t)

    return _rotate(q), _rotate(k)
73
+
74
+
75
+ def _apply_rope_real(q, k, freqs_cis):
76
+ cos = freqs_cis.real.unsqueeze(0).unsqueeze(2)
77
+ sin = freqs_cis.imag.unsqueeze(0).unsqueeze(2)
78
+
79
+ def _rotate(x):
80
+ pairs = x.float().reshape(*x.shape[:-1], -1, 2)
81
+ r, i = pairs[..., 0], pairs[..., 1]
82
+ out = torch.stack([r * cos - i * sin, r * sin + i * cos], dim=-1).flatten(-2)
83
+ return out.type_as(x)
84
+
85
+ return _rotate(q), _rotate(k)
86
+
87
+
88
def repeat_kv(x: torch.Tensor, n_rep: int) -> torch.Tensor:
    """Duplicate each KV head ``n_rep`` times along the head axis.

    Turns a (batch, seq, n_kv_heads, head_dim) tensor into
    (batch, seq, n_kv_heads * n_rep, head_dim) so GQA keys/values line up
    with the full set of query heads. Returns the input unchanged when no
    repetition is needed.
    """
    if n_rep == 1:
        return x
    batch, seq, n_kv, head_dim = x.shape
    expanded = x.unsqueeze(3).expand(batch, seq, n_kv, n_rep, head_dim)
    return expanded.reshape(batch, seq, n_kv * n_rep, head_dim)
93
+
94
+
95
+ # ---------------------------------------------------------------------------
96
+ # Attention with Value Residual Learning
97
+ # ---------------------------------------------------------------------------
98
+
99
class GQAWithValueResidual(nn.Module):
    """Grouped-Query Attention with optional Value Residual Learning.

    ``n_heads`` query heads share ``n_kv_heads`` key/value heads (GQA). When
    value residual learning is enabled, each layer's value vectors are blended
    with the *first* layer's raw values through a learned sigmoid gate.
    """

    def __init__(self, config: JeevesConfig):
        super().__init__()
        self.d_model = config.d_model
        self.n_heads = config.n_heads
        self.n_kv_heads = config.n_kv_heads
        self.head_dim = config.head_dim
        # Number of query heads served by each KV head.
        self.n_kv_groups = config.n_heads // config.n_kv_heads
        self.use_flash_attention = config.use_flash_attention
        self.use_value_residual = config.use_value_residual

        self.q_proj = nn.Linear(config.d_model, config.n_heads * config.head_dim, bias=False)
        self.k_proj = nn.Linear(config.d_model, config.n_kv_heads * config.head_dim, bias=False)
        self.v_proj = nn.Linear(config.d_model, config.n_kv_heads * config.head_dim, bias=False)
        self.o_proj = nn.Linear(config.n_heads * config.head_dim, config.d_model, bias=False)
        self.attn_dropout = nn.Dropout(config.dropout) if config.dropout > 0 else nn.Identity()

        if config.use_value_residual:
            # Gate logit for blending current vs first-layer values; with the
            # default init of -2.0, sigmoid gives ~0.12, so training starts
            # weighted toward the current layer's own values.
            self.alpha_logit = nn.Parameter(torch.tensor(config.value_residual_alpha_init))

    def forward(self, x, freqs_cis, mask=None, first_layer_v=None):
        """Attend over ``x``.

        Args:
            x: (batch, seq_len, d_model) hidden states.
            freqs_cis: complex RoPE table sliced to the current sequence length.
            mask: additive causal mask; used only on the manual attention path.
            first_layer_v: raw (pre-blend) value tensor from the first layer,
                when value residual learning is active.

        Returns:
            Tuple of (attention output, this layer's raw pre-blend values).
        """
        batch, seq_len, _ = x.shape
        q = self.q_proj(x).view(batch, seq_len, self.n_heads, self.head_dim)
        k = self.k_proj(x).view(batch, seq_len, self.n_kv_heads, self.head_dim)
        v = self.v_proj(x).view(batch, seq_len, self.n_kv_heads, self.head_dim)
        # Keep the unblended values: later layers blend against layer 0's raw v.
        raw_v = v

        if self.use_value_residual and first_layer_v is not None:
            alpha = torch.sigmoid(self.alpha_logit)
            v = (1.0 - alpha) * v + alpha * first_layer_v

        # RoPE is applied before expanding KV heads to the full head count.
        q, k = apply_rope(q, k, freqs_cis)
        k = repeat_kv(k, self.n_kv_groups)
        v = repeat_kv(v, self.n_kv_groups)

        # (batch, seq, heads, dim) -> (batch, heads, seq, dim) for attention.
        q, k, v = q.transpose(1, 2), k.transpose(1, 2), v.transpose(1, 2)

        is_accel = q.is_cuda or q.device.type == 'mps'
        if self.use_flash_attention and is_accel:
            # Fused kernel constructs its own causal mask; the `mask` argument
            # is ignored on this path, so padding is not masked out here.
            attn_out = F.scaled_dot_product_attention(q, k, v, attn_mask=None, is_causal=True)
        else:
            scale = 1.0 / math.sqrt(self.head_dim)
            scores = torch.matmul(q, k.transpose(-2, -1)) * scale
            if mask is not None:
                scores = scores + mask
            # Softmax in fp32 for stability, then cast back to the input dtype.
            w = F.softmax(scores, dim=-1, dtype=torch.float32).type_as(q)
            w = self.attn_dropout(w)
            attn_out = torch.matmul(w, v)

        attn_out = attn_out.transpose(1, 2).contiguous().view(batch, seq_len, -1)
        return self.o_proj(attn_out), raw_v
152
+
153
+
154
+ # ---------------------------------------------------------------------------
155
+ # Transformer Block
156
+ # ---------------------------------------------------------------------------
157
+
158
class TransformerBlock(nn.Module):
    """Pre-norm transformer block: RMSNorm -> attention -> residual add,
    then RMSNorm -> SwiGLU FFN -> residual add. Passes the attention layer's
    raw value tensor through for value residual learning."""

    def __init__(self, config: JeevesConfig):
        super().__init__()
        self.attn_norm = RMSNorm(config.d_model, eps=config.norm_eps)
        self.attention = GQAWithValueResidual(config)
        self.ffn_norm = RMSNorm(config.d_model, eps=config.norm_eps)
        self.ffn = SwiGLUFFN(config.d_model, config.d_ff, config.dropout)

    def forward(self, x, freqs_cis, mask=None, first_layer_v=None):
        attn_out, raw_v = self.attention(self.attn_norm(x), freqs_cis, mask, first_layer_v)
        hidden = x + attn_out
        hidden = hidden + self.ffn(self.ffn_norm(hidden))
        return hidden, raw_v
171
+
172
+
173
+ # ---------------------------------------------------------------------------
174
+ # Jeeves Model (HuggingFace-compatible)
175
+ # ---------------------------------------------------------------------------
176
+
177
class JeevesForCausalLM(PreTrainedModel, GenerationMixin):
    """Jeeves: Looped Transformer + Value Residual Learning.

    Loads native Jeeves weights directly — no conversion needed.
    """
    # Wires this model to its config class for AutoModel/trust_remote_code.
    config_class = JeevesConfig
    supports_gradient_checkpointing = False
    # lm_head shares storage with tok_emb when config.tie_embeddings is set.
    _tied_weights_keys = {"lm_head.weight": "tok_emb.weight"}

    def __init__(self, config: JeevesConfig):
        """Build embeddings, the (optionally looped) layer stack, and LM head."""
        super().__init__(config)
        self.config = config

        # Embedding
        self.tok_emb = nn.Embedding(config.vocab_size, config.d_model)

        # Layer structure
        if config.loop_block_idx is not None:
            # Looped layout: [early layers] + one shared block executed
            # n_loop_iters times + [late layers].
            n_early = config.loop_block_idx
            n_late = config.n_layers - config.loop_block_idx - 1
            self.early_layers = nn.ModuleList([TransformerBlock(config) for _ in range(n_early)])
            self.loop_block = TransformerBlock(config)
            self.late_layers = nn.ModuleList([TransformerBlock(config) for _ in range(n_late)])
            self.n_loop_iters = config.n_loop_iters
            self.use_input_injection = config.use_input_injection
            self.looped = True
        else:
            # Plain (non-looped) stack of n_layers distinct blocks.
            self.layers = nn.ModuleList([TransformerBlock(config) for _ in range(config.n_layers)])
            self.looped = False

        self.norm = RMSNorm(config.d_model, eps=config.norm_eps)
        self.lm_head = nn.Linear(config.d_model, config.vocab_size, bias=False)

        if config.tie_embeddings:
            self.lm_head.weight = self.tok_emb.weight

        # Store RoPE params — freqs_cis is computed fresh in forward()
        # to avoid corruption from HF's meta-device initialization
        self._rope_head_dim = config.head_dim
        self._rope_max_seq_len = config.max_seq_len
        self._rope_base = config.rope_base
        self._freqs_cache = None

        self.post_init()

    def get_input_embeddings(self):
        return self.tok_emb

    def set_input_embeddings(self, value):
        self.tok_emb = value

    def get_output_embeddings(self):
        return self.lm_head

    def set_output_embeddings(self, new_embeddings):
        self.lm_head = new_embeddings

    def _get_freqs_cis(self, seq_len: int, device: torch.device) -> torch.Tensor:
        """Get RoPE frequencies, computing and caching on first call."""
        # Recompute when the cache is cold or the model has moved devices.
        if self._freqs_cache is None or self._freqs_cache.device != device:
            self._freqs_cache = precompute_rope_freqs(
                self._rope_head_dim, self._rope_max_seq_len, self._rope_base, device
            )
        return self._freqs_cache[:seq_len]

    def _make_causal_mask(self, seq_len, device):
        # Additive mask: -inf strictly above the diagonal, 0 on/below it.
        mask = torch.full((seq_len, seq_len), float("-inf"), device=device)
        return torch.triu(mask, diagonal=1)

    def forward(
        self,
        input_ids: Optional[torch.LongTensor] = None,
        attention_mask: Optional[torch.Tensor] = None,
        labels: Optional[torch.LongTensor] = None,
        inputs_embeds: Optional[torch.FloatTensor] = None,
        **kwargs,
    ) -> CausalLMOutputWithPast:
        """Run the full stack and optionally compute the LM loss.

        NOTE(review): ``attention_mask`` is accepted for HF-API compatibility
        but never used below — padded positions are not masked. Confirm
        callers only send unpadded batches.
        """
        if inputs_embeds is None:
            h = self.tok_emb(input_ids)
        else:
            h = inputs_embeds

        batch, seq_len, _ = h.shape
        device = h.device
        freqs_cis = self._get_freqs_cis(seq_len, device)

        # The explicit mask is only needed on the manual attention path; the
        # fused SDPA path builds its own causal mask (is_causal=True).
        mask = None
        is_accel = h.is_cuda or h.device.type == 'mps'
        if not self.config.use_flash_attention or not is_accel:
            mask = self._make_causal_mask(seq_len, device)

        first_layer_v = None

        if self.looped:
            # Early layers
            for i, layer in enumerate(self.early_layers):
                h, raw_v = layer(h, freqs_cis, mask, first_layer_v)
                if i == 0 and self.config.use_value_residual:
                    # Anchor value-residual blending on the first layer's raw values.
                    first_layer_v = raw_v

            # Looped block with input injection
            loop_input = h
            for loop_iter in range(self.n_loop_iters):
                h, _ = self.loop_block(h, freqs_cis, mask, first_layer_v)
                if self.use_input_injection and loop_iter < self.n_loop_iters - 1:
                    # Re-add the loop's input between iterations (not after the
                    # last one) to stabilize the iterated block.
                    h = h + loop_input

            # Late layers
            for layer in self.late_layers:
                h, _ = layer(h, freqs_cis, mask, first_layer_v)
        else:
            for i, layer in enumerate(self.layers):
                h, raw_v = layer(h, freqs_cis, mask, first_layer_v)
                if i == 0 and self.config.use_value_residual:
                    first_layer_v = raw_v

        h = self.norm(h)
        logits = self.lm_head(h)

        loss = None
        if labels is not None:
            # NOTE(review): logits and labels are compared at the same
            # positions — there is no one-token shift here, unlike most HF
            # causal LMs. This assumes the training pipeline pre-shifts
            # labels; confirm before fine-tuning with raw input_ids as labels.
            loss = F.cross_entropy(
                logits.view(-1, self.config.vocab_size),
                labels.view(-1),
                ignore_index=-100,
            )

        return CausalLMOutputWithPast(
            loss=loss,
            logits=logits,
        )

    def prepare_inputs_for_generation(self, input_ids, **kwargs):
        # No KV cache: each decode step re-runs the entire prefix through the
        # model (O(n^2) generation, but simple and correct).
        return {"input_ids": input_ids}

    @staticmethod
    def _reorder_cache(past, beam_idx):
        # forward() never returns past_key_values, so beam-search reordering
        # is a no-op.
        return past
special_tokens_map.json ADDED
@@ -0,0 +1,15 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "bos_token": "<s>",
3
+ "eos_token": "</s>",
4
+ "unk_token": "<unk>",
5
+ "pad_token": "<pad>",
6
+ "additional_special_tokens": [
7
+ "<|im_start|>",
8
+ "<|im_end|>",
9
+ "<|tool_call|>",
10
+ "<|tool_result|>",
11
+ "<|system|>",
12
+ "<|user|>",
13
+ "<|assistant|>"
14
+ ]
15
+ }
tokenization_jeeves.py ADDED
@@ -0,0 +1,106 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """Custom SentencePiece tokenizer for Jeeves model.
2
+
3
+ Wraps SentencePiece directly for exact token ID match with training.
4
+
5
+ Usage:
6
+ from transformers import AutoTokenizer
7
+ tokenizer = AutoTokenizer.from_pretrained("REPO_ID", trust_remote_code=True)
8
+ """
9
+
10
+ import os
11
+ from typing import Dict, List, Optional, Tuple
12
+
13
+ import sentencepiece as spm
14
+ from transformers import PreTrainedTokenizer
15
+
16
+
17
class JeevesTokenizer(PreTrainedTokenizer):
    """SentencePiece BPE tokenizer for Jeeves models.

    Wraps a raw SentencePiece model directly so token IDs match training
    exactly; chat/tool role markers are registered as added special tokens
    on top of the SentencePiece vocabulary.
    """

    # Tells HF which file this tokenizer loads/saves as its vocabulary.
    vocab_files_names = {"vocab_file": "tokenizer.model"}
    model_input_names = ["input_ids", "attention_mask"]

    def __init__(
        self,
        vocab_file: str,
        bos_token: str = "<s>",
        eos_token: str = "</s>",
        unk_token: str = "<unk>",
        pad_token: str = "<pad>",
        chat_template: Optional[str] = None,
        additional_special_tokens: Optional[List[str]] = None,
        **kwargs,
    ):
        # NOTE(review): sp_model is loaded before super().__init__ — the base
        # class appears to need the vocab methods while registering special
        # tokens; confirm against the installed transformers version before
        # reordering.
        self.vocab_file = vocab_file
        self.sp_model = spm.SentencePieceProcessor()
        self.sp_model.Load(vocab_file)

        # Default chat/tool role markers used by the Jeeves chat template.
        if additional_special_tokens is None:
            additional_special_tokens = [
                "<|im_start|>", "<|im_end|>",
                "<|tool_call|>", "<|tool_result|>",
                "<|system|>", "<|user|>", "<|assistant|>",
            ]

        super().__init__(
            bos_token=bos_token, eos_token=eos_token,
            unk_token=unk_token, pad_token=pad_token,
            additional_special_tokens=additional_special_tokens,
            chat_template=chat_template, **kwargs,
        )

    @property
    def vocab_size(self) -> int:
        # Size of the base SentencePiece vocab (added tokens not included).
        return self.sp_model.GetPieceSize()

    def get_vocab(self) -> Dict[str, int]:
        """Return piece -> id for the base vocab plus any added tokens."""
        vocab = {self.sp_model.IdToPiece(i): i for i in range(self.vocab_size)}
        vocab.update(self.added_tokens_encoder)
        return vocab

    def _tokenize(self, text: str) -> List[str]:
        return self.sp_model.EncodeAsPieces(text)

    def _convert_token_to_id(self, token: str) -> int:
        return self.sp_model.PieceToId(token)

    def _convert_id_to_token(self, index: int) -> str:
        # Out-of-range ids decode to the unk token instead of raising inside
        # SentencePiece.
        if index < 0 or index >= self.vocab_size:
            return self.unk_token
        return self.sp_model.IdToPiece(index)

    def convert_tokens_to_string(self, tokens: List[str]) -> str:
        return self.sp_model.DecodePieces(tokens)

    def save_vocabulary(
        self, save_directory: str, filename_prefix: Optional[str] = None
    ) -> Tuple[str]:
        """Copy the raw SentencePiece model file into ``save_directory``."""
        if not os.path.isdir(save_directory):
            os.makedirs(save_directory, exist_ok=True)
        out_path = os.path.join(
            save_directory,
            (filename_prefix + "-" if filename_prefix else "") + "tokenizer.model",
        )
        # Skip the copy when saving over the exact file we loaded from.
        if os.path.abspath(self.vocab_file) != os.path.abspath(out_path):
            import shutil
            shutil.copyfile(self.vocab_file, out_path)
        return (out_path,)

    def build_inputs_with_special_tokens(self, token_ids_0, token_ids_1=None):
        # No BOS/EOS is inserted automatically (matches the repo's
        # add_bos_token=false / add_eos_token=false tokenizer config).
        if token_ids_1 is None:
            return token_ids_0
        return token_ids_0 + token_ids_1

    def get_special_tokens_mask(self, token_ids_0, token_ids_1=None, already_has_special_tokens=False):
        if already_has_special_tokens:
            return super().get_special_tokens_mask(
                token_ids_0=token_ids_0, token_ids_1=token_ids_1,
                already_has_special_tokens=True,
            )
        # build_inputs_with_special_tokens adds nothing, so the mask is all 0s.
        n = len(token_ids_0) + (len(token_ids_1) if token_ids_1 else 0)
        return [0] * n

    def create_token_type_ids_from_sequences(self, token_ids_0, token_ids_1=None):
        if token_ids_1 is None:
            return [0] * len(token_ids_0)
        return [0] * len(token_ids_0) + [1] * len(token_ids_1)
tokenizer.model ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:46506a140c02b7c782c85f17e5bf6ff82b1fef925614d5adcc0f2d533c3100c3
3
+ size 539783
tokenizer_config.json ADDED
@@ -0,0 +1,26 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "auto_map": {
3
+ "AutoTokenizer": [
4
+ "tokenization_jeeves.JeevesTokenizer",
5
+ null
6
+ ]
7
+ },
8
+ "tokenizer_class": "JeevesTokenizer",
9
+ "bos_token": "<s>",
10
+ "eos_token": "</s>",
11
+ "unk_token": "<unk>",
12
+ "pad_token": "<pad>",
13
+ "chat_template": "{% for message in messages %}<|im_start|>{{ message['role'] }}\n{{ message['content'] }}<|im_end|>\n{% endfor %}{% if add_generation_prompt %}<|im_start|>assistant\n{% endif %}",
14
+ "add_bos_token": false,
15
+ "add_eos_token": false,
16
+ "clean_up_tokenization_spaces": false,
17
+ "additional_special_tokens": [
18
+ "<|im_start|>",
19
+ "<|im_end|>",
20
+ "<|tool_call|>",
21
+ "<|tool_result|>",
22
+ "<|system|>",
23
+ "<|user|>",
24
+ "<|assistant|>"
25
+ ]
26
+ }