ricdomolm commited on
Commit
9879210
·
verified ·
1 Parent(s): 9148728

upload ckpt-2000 (SWE-12h v2 lr=2e-5)

Browse files
README.md ADDED
@@ -0,0 +1,78 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ ---
2
+ license: apache-2.0
3
+ language:
4
+ - en
5
+ tags:
6
+ - code
7
+ - swe-bench
8
+ - agentic
9
+ - sft
10
+ library_name: transformers
11
+ pipeline_tag: text-generation
12
+ ---
13
+
14
+ # talkie-1930-coder
15
+
16
+ 13B model fine-tuned on agentic software-engineering trajectories from
17
+ [SWE-smith](https://github.com/SWE-bench/SWE-smith), starting from the
18
+ `talkie-1930` base. Tuned for the
19
+ [mini-swe-agent](https://github.com/SWE-bench/mini-swe-agent) interaction
20
+ format.
21
+
22
+ ## SWE-bench-Verified-Working-Harbor pass@1
23
+
24
+ | metric | value |
25
+ |---|---|
26
+ | **pass@1** (n=5 independent eval runs) | **4.48% ± 0.69 pp** |
27
+ | per-run resolved (out of 446) | 23, 18, 20, 23, 16 |
28
+
29
+ Eval pipeline: vLLM (`--model-impl transformers --max-model-len 32768
30
+ --dtype bfloat16`) → mini-swe-agent (`mini-extra swebench`, temperature 0.7,
31
+ `max_tokens=4096`), graded with the swebench harness against
32
+ `ricdomolm/SWE-bench_Verified-Working-Harbor`.
33
+
34
+ ## Training recipe
35
+
36
+ | | |
37
+ |---|---|
38
+ | Base model | `talkie-1930-13b-base` |
39
+ | Dataset | `talkie-1930-swe-100k-64k` (100k SWE-smith trajectories, packed at 64k) |
40
+ | Trainer | TRL `SFTTrainer` via `accelerate` (8× A100) |
41
+ | Optimizer | `adamw_torch_fused`, β=(0.9, 0.95), ε=1e-8 |
42
+ | LR | 2e-5, `cosine_with_min_lr`, warmup 3% |
43
+ | Precision | bf16 |
44
+ | Weight decay | 0.1 |
45
+ | Max grad norm | 30 |
46
+ | Max length | 65,536 |
47
+ | Packing | `bfd` + padding-free |
48
+ | Loss | `completion_only_loss=1` (loss only on assistant tokens) |
49
+ | Steps | 2,016 (this is ckpt-2000) |
50
+
51
+ ## Usage
52
+
53
+ This model uses custom modeling code (`modeling_talkie.py`,
54
+ `configuration_talkie.py`). Load with `trust_remote_code=True`:
55
+
56
+ ```python
57
+ from transformers import AutoModelForCausalLM, AutoTokenizer
58
+
59
+ model = AutoModelForCausalLM.from_pretrained(
60
+ "ricdomolm/talkie-1930-coder",
61
+ trust_remote_code=True,
62
+ torch_dtype="bfloat16",
63
+ )
64
+ tokenizer = AutoTokenizer.from_pretrained("ricdomolm/talkie-1930-coder")
65
+ ```
66
+
67
+ For agentic eval, serve with vLLM and drive with mini-swe-agent:
68
+
69
+ ```bash
70
+ vllm serve ricdomolm/talkie-1930-coder \
71
+ --model-impl transformers --max-model-len 32768 --dtype bfloat16
72
+ ```
73
+
74
+ ## Companion model
75
+
76
+ [`ricdomolm/talkie-web-coder`](https://huggingface.co/ricdomolm/talkie-web-coder)
77
+ — same recipe, same SFT data, but starting from a base model pre-trained
78
+ on web-style data. Reaches 5.75% ± 1.04 pp on the same eval (n=3).
chat_template.jinja ADDED
@@ -0,0 +1 @@
 
 
1
+ {% for message in messages %}{% if message['role'] == 'system' %}<|system|>{{ message['content'] }}<|end|>{% elif message['role'] == 'user' %}<|user|>{{ message['content'] }}<|end|>{% elif message['role'] == 'assistant' %}<|assistant|>{{ message['content'] }}<|end|>{% endif %}{% endfor %}{% if add_generation_prompt %}<|assistant|>{% endif %}
config.json ADDED
@@ -0,0 +1,22 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "architectures": [
3
+ "TalkieForCausalLM"
4
+ ],
5
+ "auto_map": {
6
+ "AutoConfig": "configuration_talkie.TalkieConfig",
7
+ "AutoModelForCausalLM": "modeling_talkie.TalkieForCausalLM",
8
+ "AutoModel": "modeling_talkie.TalkieModel"
9
+ },
10
+ "dtype": "bfloat16",
11
+ "head_dim": 128,
12
+ "hidden_size": 5120,
13
+ "intermediate_size": 13696,
14
+ "max_position_embeddings": 65536,
15
+ "model_type": "talkie",
16
+ "num_attention_heads": 40,
17
+ "num_hidden_layers": 40,
18
+ "rope_theta": 40000000.0,
19
+ "tie_word_embeddings": false,
20
+ "transformers_version": "4.57.3",
21
+ "vocab_size": 65540
22
+ }
configuration_talkie.py ADDED
@@ -0,0 +1,41 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """Talkie model configuration for HuggingFace Transformers."""
2
+
3
+ from transformers import PretrainedConfig
4
+
5
+
6
+ class TalkieConfig(PretrainedConfig):
7
+ """Configuration class for the Talkie 13B decoder-only transformer.
8
+
9
+ This is a 40-layer, 40-head GPT with RoPE, SwiGLU, RMS normalisation,
10
+ embedding skip connections, and per-head / per-layer gain parameters.
11
+ """
12
+
13
+ model_type = "talkie"
14
+
15
+ def __init__(
16
+ self,
17
+ vocab_size: int = 65540,
18
+ hidden_size: int = 5120,
19
+ intermediate_size: int = 13696,
20
+ num_hidden_layers: int = 40,
21
+ num_attention_heads: int = 40,
22
+ head_dim: int = 128,
23
+ max_position_embeddings: int = 2048,
24
+ rope_theta: float = 1_000_000.0,
25
+ torch_dtype: str = "bfloat16",
26
+ tie_word_embeddings: bool = False,
27
+ **kwargs,
28
+ ):
29
+ self.vocab_size = vocab_size
30
+ self.hidden_size = hidden_size
31
+ self.intermediate_size = intermediate_size
32
+ self.num_hidden_layers = num_hidden_layers
33
+ self.num_attention_heads = num_attention_heads
34
+ self.head_dim = head_dim
35
+ self.max_position_embeddings = max_position_embeddings
36
+ self.rope_theta = rope_theta
37
+ super().__init__(
38
+ tie_word_embeddings=tie_word_embeddings,
39
+ torch_dtype=torch_dtype,
40
+ **kwargs,
41
+ )
generation_config.json ADDED
@@ -0,0 +1,6 @@
 
 
 
 
 
 
 
1
+ {
2
+ "_from_model_config": true,
3
+ "eos_token_id": [65536, 65535],
4
+ "pad_token_id": 65535,
5
+ "transformers_version": "4.57.3"
6
+ }
model.safetensors ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:cfa38de8e0ee2609c788c8f76de0d12d1934044d71af22b04d94d49216c76fd8
3
+ size 26560565016
modeling_talkie.py ADDED
@@ -0,0 +1,465 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """Talkie 13B transformer — patched for long-context SFT.
2
+
3
+ Differences vs lewtun/talkie-1930-13b-it-hf upstream:
4
+ 1. Liger fused linear cross-entropy in the loss path so the float32 logits
5
+ tensor (shape S x V) is never materialised in HBM. Roughly 16 GB saved at
6
+ S=64K, V=65540.
7
+ 2. FlashAttention varlen path keyed off `position_ids`. When TRL passes a
8
+ packed sequence (padding_free=True), tokens from different documents do
9
+ not attend across boundaries.
10
+ 3. Gradient checkpointing on the decoder stack.
11
+ 4. RoPE precompute is configurable via config.max_position_embeddings; we set
12
+ it to 64K at load time.
13
+ """
14
+ from __future__ import annotations
15
+
16
+ import math
17
+ from typing import Optional, Tuple, Union
18
+
19
+ import torch
20
+ import torch.nn as nn
21
+ import torch.nn.functional as F
22
+ from transformers import GenerationMixin, PreTrainedModel
23
+ from transformers.modeling_outputs import (
24
+ BaseModelOutputWithPast,
25
+ CausalLMOutputWithPast,
26
+ )
27
+ from transformers.modeling_utils import ALL_ATTENTION_FUNCTIONS
28
+
29
+ from .configuration_talkie import TalkieConfig
30
+
31
+ try:
32
+ from flash_attn import flash_attn_varlen_func
33
+ _HAS_FA = True
34
+ except ImportError:
35
+ _HAS_FA = False
36
+
37
+ try:
38
+ from liger_kernel.transformers.fused_linear_cross_entropy import (
39
+ LigerFusedLinearCrossEntropyLoss,
40
+ )
41
+ _HAS_LIGER = True
42
+ except ImportError:
43
+ _HAS_LIGER = False
44
+
45
+
46
+ from dataclasses import dataclass, field
47
+
48
+
49
+ @dataclass
50
+ class TalkieCausalLMOutput(CausalLMOutputWithPast):
51
+ """CausalLMOutputWithPast plus a token_accuracy field expected by TRL when
52
+ SFTConfig.use_liger_kernel=True."""
53
+ token_accuracy: Optional[torch.Tensor] = None
54
+
55
+
56
+ class TalkieHeadGain(nn.Module):
57
+ def __init__(self, n_head: int):
58
+ super().__init__()
59
+ self.head_g = nn.Parameter(torch.ones(n_head))
60
+
61
+ def forward(self, x: torch.Tensor) -> torch.Tensor:
62
+ return x * self.head_g.type_as(x).view(1, 1, -1, 1)
63
+
64
+
65
+ class TalkieWeightGain(nn.Module):
66
+ def __init__(self):
67
+ super().__init__()
68
+ self.w_g = nn.Parameter(torch.ones(1))
69
+
70
+ def forward(self, w: torch.Tensor) -> torch.Tensor:
71
+ return w * self.w_g.type_as(w)
72
+
73
+
74
+ class TalkieActGain(nn.Module):
75
+ def __init__(self, init_value: float):
76
+ super().__init__()
77
+ self.a_g = nn.Parameter(torch.ones(1) * init_value)
78
+
79
+ def forward(self, x: torch.Tensor) -> torch.Tensor:
80
+ return x * self.a_g.type_as(x)
81
+
82
+
83
+ def _apply_rotary_emb(
84
+ x: torch.Tensor, cos: torch.Tensor, sin: torch.Tensor
85
+ ) -> torch.Tensor:
86
+ assert x.ndim == 4
87
+ d = x.shape[3] // 2
88
+ x1 = x[..., :d]
89
+ x2 = x[..., d:]
90
+ y1 = x1 * cos + x2 * sin
91
+ y2 = x1 * (-sin) + x2 * cos
92
+ return torch.cat([y1, y2], 3).type_as(x)
93
+
94
+
95
+ def _precompute_rotary_embeddings(
96
+ seq_len: int, head_dim: int, base: float, device: torch.device
97
+ ) -> Tuple[torch.Tensor, torch.Tensor]:
98
+ channel_range = torch.arange(0, head_dim, 2, dtype=torch.float32, device=device)
99
+ inv_freq = 1.0 / (base ** (channel_range / head_dim))
100
+ t = torch.arange(seq_len, dtype=torch.float32, device=device)
101
+ freqs = torch.outer(t, inv_freq)
102
+ cos, sin = freqs.cos(), freqs.sin()
103
+ cos, sin = cos.bfloat16(), sin.bfloat16()
104
+ cos, sin = cos[None, :, None, :], sin[None, :, None, :]
105
+ return cos, sin
106
+
107
+
108
+ def _gather_rope_per_position(
109
+ cos: torch.Tensor, sin: torch.Tensor, position_ids: torch.Tensor
110
+ ) -> Tuple[torch.Tensor, torch.Tensor]:
111
+ """Index RoPE tables by position_ids.
112
+
113
+ cos/sin: (1, S_table, 1, D_half)
114
+ position_ids: (B, S)
115
+ returns (B, S, 1, D_half) bf16
116
+ """
117
+ cos_t = cos[0, :, 0, :] # (S_table, D_half)
118
+ sin_t = sin[0, :, 0, :]
119
+ flat = position_ids.reshape(-1)
120
+ cos_g = cos_t.index_select(0, flat).reshape(*position_ids.shape, 1, cos_t.shape[-1])
121
+ sin_g = sin_t.index_select(0, flat).reshape(*position_ids.shape, 1, sin_t.shape[-1])
122
+ return cos_g, sin_g
123
+
124
+
125
+ def _cu_seqlens_from_position_ids(position_ids: torch.Tensor) -> torch.Tensor:
126
+ """Convert per-token position_ids (where each new doc restarts at 0) into
127
+ cu_seqlens suitable for flash_attn_varlen_func.
128
+
129
+ Expects shape (B, S). For B>1 flatten before calling. Returns only cu_seqlens;
130
+ the caller can pass the total sequence length as an over-approximation of
131
+ max_seqlen to avoid a forced .item() sync (which torch.compile breaks on).
132
+ """
133
+ pos = position_ids.reshape(-1)
134
+ starts = (pos == 0).nonzero(as_tuple=False).squeeze(-1)
135
+ cu = torch.cat(
136
+ [starts, torch.tensor([pos.numel()], device=pos.device, dtype=starts.dtype)]
137
+ ).to(torch.int32)
138
+ return cu
139
+
140
+
141
+ class TalkieSelfAttention(nn.Module):
142
+ is_causal = True
143
+
144
+ def __init__(self, config: TalkieConfig, layer_idx: int = 0):
145
+ super().__init__()
146
+ self.config = config
147
+ self.layer_idx = layer_idx
148
+ self.n_head = config.num_attention_heads
149
+ self.head_dim = config.head_dim
150
+ self.scaling = 1.0 / math.sqrt(self.head_dim)
151
+ n_state = config.hidden_size
152
+
153
+ self.attn_query = nn.Linear(n_state, n_state, bias=False)
154
+ self.attn_key = nn.Linear(n_state, n_state, bias=False)
155
+ self.attn_value = nn.Linear(n_state, n_state, bias=False)
156
+ self.attn_resid = nn.Linear(n_state, n_state, bias=False)
157
+ self.head_gain = TalkieHeadGain(config.num_attention_heads)
158
+
159
+ def forward(
160
+ self,
161
+ x: torch.Tensor,
162
+ cos_sin: Tuple[torch.Tensor, torch.Tensor],
163
+ cu_seqlens: Optional[torch.Tensor] = None,
164
+ max_seqlen: Optional[int] = None,
165
+ **kwargs,
166
+ ) -> torch.Tensor:
167
+ bsz, seq_len, _ = x.size()
168
+ q = self.attn_query(x).view(bsz, seq_len, self.n_head, self.head_dim)
169
+ k = self.attn_key(x).view(bsz, seq_len, self.n_head, self.head_dim)
170
+ v = self.attn_value(x).view(bsz, seq_len, self.n_head, self.head_dim)
171
+
172
+ cos, sin = cos_sin
173
+ q, k = _apply_rotary_emb(q, cos, sin), _apply_rotary_emb(k, cos, sin)
174
+ q, k = F.rms_norm(q, (q.size(-1),)), F.rms_norm(k, (k.size(-1),))
175
+ q = self.head_gain(q)
176
+
177
+ if cu_seqlens is not None and _HAS_FA:
178
+ assert bsz == 1, "varlen path expects flattened batch"
179
+ q_f = q.reshape(seq_len, self.n_head, self.head_dim)
180
+ k_f = k.reshape(seq_len, self.n_head, self.head_dim)
181
+ v_f = v.reshape(seq_len, self.n_head, self.head_dim)
182
+ y = flash_attn_varlen_func(
183
+ q_f,
184
+ k_f,
185
+ v_f,
186
+ cu_seqlens_q=cu_seqlens,
187
+ cu_seqlens_k=cu_seqlens,
188
+ max_seqlen_q=max_seqlen,
189
+ max_seqlen_k=max_seqlen,
190
+ causal=True,
191
+ )
192
+ y = y.reshape(bsz, seq_len, self.n_head * self.head_dim)
193
+ else:
194
+ attn_impl = getattr(self.config, "_attn_implementation", "sdpa")
195
+ attn_fn = ALL_ATTENTION_FUNCTIONS.get(attn_impl)
196
+ if attn_fn is None:
197
+ attn_fn = ALL_ATTENTION_FUNCTIONS["sdpa"]
198
+ y, _ = attn_fn(
199
+ self,
200
+ q.transpose(1, 2),
201
+ k.transpose(1, 2),
202
+ v.transpose(1, 2),
203
+ attention_mask=None,
204
+ scaling=self.scaling,
205
+ dropout=0.0,
206
+ is_causal=True,
207
+ **kwargs,
208
+ )
209
+ y = y.contiguous().view(bsz, seq_len, self.n_head * self.head_dim)
210
+ return self.attn_resid(y)
211
+
212
+
213
+ class TalkieMLP(nn.Module):
214
+ def __init__(self, config: TalkieConfig):
215
+ super().__init__()
216
+ n_state = config.hidden_size
217
+ n_mlp = config.intermediate_size
218
+
219
+ self.mlp_gate = nn.Linear(n_state, n_mlp, bias=False)
220
+ self.mlp_linear = nn.Linear(n_state, n_mlp, bias=False)
221
+ self.mlp_resid = nn.Linear(n_mlp, n_state, bias=False)
222
+
223
+ def forward(self, x: torch.Tensor) -> torch.Tensor:
224
+ return self.mlp_resid(F.silu(self.mlp_gate(x)) * self.mlp_linear(x))
225
+
226
+
227
+ class TalkieDecoderLayer(nn.Module):
228
+ def __init__(self, config: TalkieConfig, layer_idx: int = 0):
229
+ super().__init__()
230
+ gain_init = (2 * config.num_hidden_layers) ** -0.5
231
+
232
+ self.layer_idx = layer_idx
233
+ self.attn = TalkieSelfAttention(config, layer_idx=layer_idx)
234
+ self.attn_gain = TalkieActGain(gain_init)
235
+ self.mlp = TalkieMLP(config)
236
+ self.mlp_gain = TalkieActGain(gain_init)
237
+ self.embed_skip = TalkieActGain(0.0)
238
+
239
+ def forward(
240
+ self,
241
+ e_x: torch.Tensor,
242
+ x: torch.Tensor,
243
+ cos_sin: Tuple[torch.Tensor, torch.Tensor],
244
+ cu_seqlens: Optional[torch.Tensor] = None,
245
+ max_seqlen: Optional[int] = None,
246
+ **kwargs,
247
+ ) -> torch.Tensor:
248
+ x = x + self.attn_gain(
249
+ self.attn(
250
+ F.rms_norm(x, (x.shape[-1],)),
251
+ cos_sin,
252
+ cu_seqlens,
253
+ max_seqlen,
254
+ **kwargs,
255
+ )
256
+ )
257
+ x = x + self.mlp_gain(self.mlp(F.rms_norm(x, (x.shape[-1],))))
258
+ x = x + self.embed_skip(e_x)
259
+ return x
260
+
261
+
262
+ class TalkieModel(PreTrainedModel):
263
+ """Decoder stack — HF-style forward so vLLM's transformers backend
264
+ (`AutoModel.from_config(...)`) can host this model."""
265
+
266
+ config_class = TalkieConfig
267
+ _no_split_modules = ["TalkieDecoderLayer"]
268
+ _supports_gradient_checkpointing = True
269
+ _supports_attention_backend = True
270
+ _supports_sdpa = True
271
+ _supports_flash_attn_2 = True
272
+ base_model_prefix = "model"
273
+ # Empty plan = single-GPU / replicate. Multi-GPU TP would need entries
274
+ # for q/k/v/o-proj. vLLM tolerates an empty plan when world_size==1.
275
+ tp_plan = {}
276
+
277
+ def __init__(self, config: TalkieConfig):
278
+ super().__init__(config)
279
+ self.embed = nn.Embedding(config.vocab_size, config.hidden_size)
280
+ self.blocks = nn.ModuleList(
281
+ [
282
+ TalkieDecoderLayer(config, layer_idx=i)
283
+ for i in range(config.num_hidden_layers)
284
+ ]
285
+ )
286
+ self.gradient_checkpointing = False
287
+ # Selective activation checkpointing: only checkpoint every Nth layer.
288
+ # stride=1 => every layer (HF default), stride=2 => half of layers,
289
+ # stride=N => no layers checkpointed. Set via env at construction time.
290
+ import os as _os
291
+ try:
292
+ self.gc_stride = max(1, int(_os.environ.get("TALKIE_GC_STRIDE", "1")))
293
+ except ValueError:
294
+ self.gc_stride = 1
295
+
296
+ self._rope_cos: torch.Tensor | None = None
297
+ self._rope_sin: torch.Tensor | None = None
298
+
299
+ def _set_gradient_checkpointing(self, enable: bool = True, gradient_checkpointing_func=None):
300
+ self.gradient_checkpointing = enable
301
+
302
+ def get_input_embeddings(self):
303
+ return self.embed
304
+
305
+ def set_input_embeddings(self, value):
306
+ self.embed = value
307
+
308
+ def _get_rope(
309
+ self, seq_len: int, device: torch.device
310
+ ) -> Tuple[torch.Tensor, torch.Tensor]:
311
+ target = max(seq_len, self.config.max_position_embeddings)
312
+ if (
313
+ self._rope_cos is None
314
+ or self._rope_cos.shape[1] < target
315
+ or self._rope_cos.device != device
316
+ ):
317
+ cos, sin = _precompute_rotary_embeddings(
318
+ target,
319
+ self.config.head_dim,
320
+ self.config.rope_theta,
321
+ device=device,
322
+ )
323
+ self._rope_cos = cos
324
+ self._rope_sin = sin
325
+ return self._rope_cos[:, :target], self._rope_sin[:, :target]
326
+
327
+ def forward(
328
+ self,
329
+ input_ids: Optional[torch.LongTensor] = None,
330
+ attention_mask: Optional[torch.Tensor] = None,
331
+ position_ids: Optional[torch.LongTensor] = None,
332
+ inputs_embeds: Optional[torch.Tensor] = None,
333
+ use_cache: Optional[bool] = None,
334
+ return_dict: Optional[bool] = None,
335
+ **kwargs,
336
+ ):
337
+ if inputs_embeds is None:
338
+ assert input_ids is not None
339
+ x = self.embed(input_ids)
340
+ seq_len = input_ids.shape[1]
341
+ device = input_ids.device
342
+ else:
343
+ x = inputs_embeds
344
+ seq_len = inputs_embeds.shape[1]
345
+ device = inputs_embeds.device
346
+
347
+ cos_table, sin_table = self._get_rope(seq_len, device)
348
+ if position_ids is not None:
349
+ cos_sin = _gather_rope_per_position(cos_table, sin_table, position_ids)
350
+ else:
351
+ cos_sin = (cos_table[:, :seq_len], sin_table[:, :seq_len])
352
+
353
+ # FlashAttention varlen path is for packed-sequence training only.
354
+ # During inference (HF generate, vLLM, etc.) we go through
355
+ # ALL_ATTENTION_FUNCTIONS instead.
356
+ cu_seqlens, max_seqlen = (None, None)
357
+ if self.training and position_ids is not None and _HAS_FA:
358
+ cu_seqlens = _cu_seqlens_from_position_ids(position_ids)
359
+ max_seqlen = seq_len
360
+
361
+ x = F.rms_norm(x, (x.shape[-1],))
362
+ e_x = x
363
+ for i, block in enumerate(self.blocks):
364
+ if (
365
+ self.gradient_checkpointing
366
+ and self.training
367
+ and (i % self.gc_stride == 0)
368
+ ):
369
+ x = torch.utils.checkpoint.checkpoint(
370
+ block,
371
+ e_x,
372
+ x,
373
+ cos_sin,
374
+ cu_seqlens,
375
+ max_seqlen,
376
+ use_reentrant=False,
377
+ )
378
+ else:
379
+ x = block(e_x, x, cos_sin, cu_seqlens, max_seqlen, **kwargs)
380
+ x = F.rms_norm(x, (x.shape[-1],))
381
+
382
+ if return_dict is False:
383
+ return (x,)
384
+ return BaseModelOutputWithPast(last_hidden_state=x)
385
+
386
+
387
+ class TalkieForCausalLM(PreTrainedModel, GenerationMixin):
388
+ config_class = TalkieConfig
389
+ _no_split_modules = ["TalkieDecoderLayer"]
390
+ _supports_gradient_checkpointing = True
391
+ supports_gradient_checkpointing = True
392
+ _supports_attention_backend = True
393
+ _supports_sdpa = True
394
+ _supports_flash_attn_2 = True
395
+
396
+ def __init__(self, config: TalkieConfig):
397
+ super().__init__(config)
398
+ self.model = TalkieModel(config)
399
+ self.lm_head = nn.Parameter(
400
+ torch.zeros(config.vocab_size, config.hidden_size)
401
+ )
402
+ self.lm_head_gain = TalkieWeightGain()
403
+
404
+ self.post_init()
405
+
406
+ def _set_gradient_checkpointing(self, enable: bool = True, gradient_checkpointing_func=None):
407
+ self.model.gradient_checkpointing = enable
408
+
409
+ def _get_rope(self, seq_len: int, device: torch.device):
410
+ # Backwards-compat shim for inference/fast_generate.py — RoPE tables
411
+ # now live on the inner TalkieModel.
412
+ return self.model._get_rope(seq_len, device)
413
+
414
+ def get_input_embeddings(self):
415
+ return self.model.embed
416
+
417
+ def set_input_embeddings(self, value):
418
+ self.model.embed = value
419
+
420
+ def prepare_inputs_for_generation(self, input_ids, **kwargs):
421
+ return {"input_ids": input_ids}
422
+
423
+ def forward(
424
+ self,
425
+ input_ids: Optional[torch.LongTensor] = None,
426
+ attention_mask: Optional[torch.Tensor] = None,
427
+ position_ids: Optional[torch.LongTensor] = None,
428
+ labels: Optional[torch.LongTensor] = None,
429
+ **kwargs,
430
+ ) -> Union[CausalLMOutputWithPast, Tuple]:
431
+ outputs = self.model(
432
+ input_ids=input_ids,
433
+ attention_mask=attention_mask,
434
+ position_ids=position_ids,
435
+ return_dict=False,
436
+ )
437
+ hidden_states = outputs[0]
438
+
439
+ loss = None
440
+ if labels is not None and _HAS_LIGER:
441
+ shift_hidden = hidden_states[..., :-1, :].contiguous()
442
+ shift_labels = labels[..., 1:].contiguous()
443
+ scaled_weight = self.lm_head_gain(self.lm_head)
444
+ loss_fn = LigerFusedLinearCrossEntropyLoss(return_token_accuracy=True)
445
+ res = loss_fn(
446
+ scaled_weight,
447
+ shift_hidden.view(-1, shift_hidden.size(-1)),
448
+ shift_labels.view(-1),
449
+ )
450
+ return TalkieCausalLMOutput(
451
+ loss=res.loss, logits=None, token_accuracy=res.token_accuracy,
452
+ )
453
+
454
+ logits = F.linear(hidden_states, self.lm_head_gain(self.lm_head))
455
+ if labels is not None:
456
+ shift_logits = logits[..., :-1, :].contiguous().float()
457
+ shift_labels = labels[..., 1:].contiguous()
458
+ loss = F.cross_entropy(
459
+ shift_logits.view(-1, shift_logits.size(-1)),
460
+ shift_labels.view(-1),
461
+ )
462
+ else:
463
+ logits = logits.float()
464
+
465
+ return CausalLMOutputWithPast(loss=loss, logits=logits)
tokenizer.json ADDED
The diff for this file is too large to render. See raw diff
 
tokenizer_config.json ADDED
@@ -0,0 +1,9 @@
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "bos_token": null,
3
+ "clean_up_tokenization_spaces": false,
4
+ "eos_token": "<|end|>",
5
+ "model_max_length": 65536,
6
+ "pad_token": "<|endoftext|>",
7
+ "tokenizer_class": "PreTrainedTokenizerFast",
8
+ "unk_token": null
9
+ }