Ubuntu committed on
Commit ecd6fbd · 1 Parent(s): 04cd47a
Files changed (3)
  1. model.safetensors +2 -2
  2. src/pre-training.py +1 -1
  3. src/tynerox/modeling.py +130 -177
model.safetensors CHANGED
@@ -1,3 +1,3 @@
 version https://git-lfs.github.com/spec/v1
-oid sha256:2dc6c386af412163c51f18f97152117040b6464f9e64159ef464d50471ceda1c
-size 1101168184
+oid sha256:a469bc2dde18f9b248e83bf2d86d1d23fa32b8ad9646317d58e9d513b37b3120
+size 832480488
src/pre-training.py CHANGED
@@ -35,7 +35,7 @@ if __name__ == "__main__":
     dataloader = create_train_dataloader(
         folder_path,
         tokenizer,
-        batch_size=20,
+        batch_size=5,
         max_length=1024,
         drop_last=True,
         num_workers=10
src/tynerox/modeling.py CHANGED
@@ -9,108 +9,37 @@ from typing import Optional, Literal, Union, Tuple
 
 
 class PositionalEncoding(nn.Module):
-    """
-    Implements positional encoding (sinusoidal or rotary).
-    """
-    def __init__(
-        self,
-        embed_dim: int,
-        context_length: int,
-        dropout: float = 0.1,
-        encoding_type: Literal['sinusoidal', 'rotary'] = 'rotary',
-    ):
+    def __init__(self, embed_dim, context_length):
         super().__init__()
-        if embed_dim <= 0 or context_length <= 0:
-            raise ValueError("embed_dim and context_length must be positive integers")
-        if not 0 <= dropout < 1:
-            raise ValueError("dropout must be between 0 and 1")
-
-        self.dropout = nn.Dropout(dropout)
-        self.encoding_type = encoding_type.lower()
-        self.max_seq_len = context_length
+        if embed_dim % 2 != 0:
+            raise ValueError("embed_dim must be even for rotary")
         self.embed_dim = embed_dim
+        self._build_table(context_length)
 
-        if self.encoding_type == 'sinusoidal':
-            pe = self._create_sinusoidal_embeddings(context_length, embed_dim)
-            self.register_buffer('pe', pe.unsqueeze(0), persistent=True)
-        elif self.encoding_type == 'rotary':
-            if embed_dim % 2 != 0:
-                raise ValueError("embed_dim must be even for rotary encoding")
-            # inv_freq of size D/2
-            inv_freq = 1.0 / (10000 ** (torch.arange(0, embed_dim, 2).float() / embed_dim))
-            self.register_buffer('inv_freq', inv_freq, persistent=True)
-        else:
-            raise ValueError("Unsupported encoding_type: 'sinusoidal' or 'rotary'")
+    def _build_table(self, length):
+        # inv_freq[j] = 1 / 10000^(2j/embed_dim)
+        inv_freq = 1.0 / (10000 ** (torch.arange(0, self.embed_dim, 2).float() / self.embed_dim))
+        positions = torch.arange(length).float().unsqueeze(1)  # [L, 1]
+        freqs = positions * inv_freq.unsqueeze(0)  # [L, D/2]
+        self.register_buffer('sin_table', freqs.sin(), persistent=True)
+        self.register_buffer('cos_table', freqs.cos(), persistent=True)
 
-    def _create_sinusoidal_embeddings(self, max_seq_len: int, dim: int) -> torch.Tensor:
-        position = torch.arange(max_seq_len).unsqueeze(1).float()
-        div_term = torch.exp(torch.arange(0, dim, 2).float() * (-math.log(10000.0) / dim))
-        pe = torch.zeros(max_seq_len, dim)
-        pe[:, 0::2] = torch.sin(position * div_term)
-        pe[:, 1::2] = torch.cos(position * div_term)
-        return pe
+    def apply_rotary(self, x, sin, cos):
+        # x: [B, T, D], sin/cos: [1, T, D/2]
+        x_pairs = x.view(*x.shape[:-1], -1, 2)  # [..., D/2, 2]
+        x1, x2 = x_pairs[..., 0], x_pairs[..., 1]
+        y1 = x1 * cos - x2 * sin
+        y2 = x1 * sin + x2 * cos
+        x_rot = torch.stack([y1, y2], dim=-1)  # [..., D/2, 2]
+        return x_rot.flatten(-2)
 
-    def forward(self, x: torch.Tensor) -> torch.Tensor:
-        # x shape: [B, T, D]
-        if self.encoding_type == 'sinusoidal':
-            seq_len = x.size(1)
-            x = x + self.pe[:, :seq_len, :]
-        else:
-            # rotary: split even/odd dims and apply rotary
-            seq_len = x.size(1)
-            positions = torch.arange(seq_len, device=x.device).type_as(self.inv_freq)
-            # freqs of shape [T, D/2]
-            freqs = torch.einsum('i , j -> i j', positions, self.inv_freq)
-            x = self.apply_rotary(x, freqs)
-        return self.dropout(x)
-
-    @staticmethod
-    def apply_rotary(x: torch.Tensor, emb: torch.Tensor) -> torch.Tensor:
-        # x: [B, T, D], emb: [T, D/2]
-        x1, x2 = x.chunk(2, dim=-1)  # each [B, T, D/2]
-        emb_sin = emb.sin()[None, :, :]  # [1, T, D/2]
-        emb_cos = emb.cos()[None, :, :]  # [1, T, D/2]
-        # apply rotary
-        rotated1 = x1 * emb_cos + x2 * emb_sin
-        rotated2 = x2 * emb_cos - x1 * emb_sin
-        return torch.cat([rotated1, rotated2], dim=-1)  # [B, T, D]
-
-
-class PositionalEmbedding(nn.Module):
-    """
-    Combines token embedding with positional encoding.
-    """
-    def __init__(
-        self,
-        vocab_size: int,
-        embed_dim: int,
-        context_length: int,
-        dropout: float = 0.05,
-        encoding_type: Literal['sinusoidal', 'rotary'] = 'rotary'
-    ):
-        super().__init__()
-        if vocab_size <= 0 or embed_dim <= 0 or context_length <= 0:
-            raise ValueError("vocab_size, embed_dim, context_length must be > 0")
-
-        self.token_embedding = nn.Embedding(vocab_size, embed_dim)
-        self.scale = math.sqrt(embed_dim)
-        self.pos_encoding = PositionalEncoding(
-            embed_dim=embed_dim,
-            context_length=context_length,
-            dropout=dropout,
-            encoding_type=encoding_type
-        )
-
-    def forward(self, input_ids: torch.LongTensor) -> torch.Tensor:
-        # input_ids: [B, T]
-        x = self.token_embedding(input_ids) * self.scale  # [B, T, D]
-        return self.pos_encoding(x)
 
 
 def get_alibi_slopes(n_heads: int) -> torch.Tensor:
     def _get_slopes(n):
         base = 2 ** (-8.0 / n)
         return torch.tensor([base ** (i + 1) for i in range(n)])
+
     if math.log2(n_heads).is_integer():
         return _get_slopes(n_heads)
     m = 2 ** math.floor(math.log2(n_heads))
@@ -118,6 +47,7 @@ def get_alibi_slopes(n_heads: int) -> torch.Tensor:
     extra = _get_slopes(2 * m)[::2][: n_heads - m]
     return torch.cat([slopes, extra], dim=0)
 
+
 # -----------------------------------------------------------------------------
 # Feed-Forward
 # -----------------------------------------------------------------------------
@@ -135,6 +65,7 @@ class FeedForward(nn.Module):
         x_up, x_gate = x_fc1.chunk(2, dim=-1)
         return self.fc2(x_up * self.activation(x_gate))
 
+
 # -----------------------------------------------------------------------------
 # Attention-Free Transformer (AFT) Simple
 # -----------------------------------------------------------------------------
@@ -143,6 +74,7 @@ class AFTSimple(nn.Module):
     def __init__(
         self,
         embed_dim: int,
+        max_position_embeddings: int,
         activation=torch.sigmoid,
         causal: bool = True,
     ):
@@ -151,6 +83,9 @@
         self.causal = causal
         self.activation = activation
 
+        # Rotary PE (no dropout, so it does not perturb Q/K)
+        self.rotary = PositionalEncoding(embed_dim, max_position_embeddings)
+
         self.qkv = nn.Linear(embed_dim, 3 * embed_dim, bias=False)
         self.project = nn.Linear(embed_dim, embed_dim)
 
@@ -159,53 +94,75 @@
         x: torch.Tensor,
         past_key_values: Optional[Tuple[torch.Tensor, torch.Tensor]] = None
     ) -> Tuple[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]:
-        # x: [B, T_new, D]
+        """
+        x: [B, T_new, D]
+        past_key_values: (K_past, V_past), each of shape [B, T_past, D]
+        """
         B, T_new, D = x.shape
-        if D != self.embed_dim:
-            raise ValueError(f"Input dim ({D}) != embed_dim ({self.embed_dim})")
+        assert D == self.embed_dim, f"Embedding dimension mismatch: {D} != {self.embed_dim}"
 
-        qkv = self.qkv(x)  # [B, T_new, 3*D]
-        Q, K_new, V_new = qkv.chunk(3, dim=-1)  # each [B, T_new, D]
+        # 1) Linear projections
+        qkv = self.qkv(x)  # [B, T_new, 3*D]
+        Q, K_new, V_new = qkv.chunk(3, dim=-1)
 
-        # concatenate past if provided
+        # 2) Compute sin/cos for the positions of Q/K_new
+        # If a cache is present, shift the positions; otherwise use 0..T_new-1
         if past_key_values is not None:
             K_past, V_past = past_key_values
-            K = torch.cat([K_past, K_new], dim=1)  # [B, T_all, D]
-            V = torch.cat([V_past, V_new], dim=1)
+            T_past = K_past.size(1)
         else:
-            K, V = K_new, V_new
+            T_past = 0
 
-        # compute attention-free aggregate
-        softmax_k = F.softmax(K, dim=1)  # [B, T_all, D]
-        weighted_v = softmax_k * V  # [B, T_all, D]
+        # get sin/cos for positions [T_past .. T_past+T_new-1]
+        device, dtype = Q.device, Q.dtype
+        pos = torch.arange(T_past, T_past + T_new, device=device)
+        sin = self.rotary.sin_table[pos].unsqueeze(0).to(dtype=dtype)  # [1, T_new, D/2]
+        cos = self.rotary.cos_table[pos].unsqueeze(0).to(dtype=dtype)
+
+        # 3) Apply RoPE to Q and K_new
+        Q_rot = self.rotary.apply_rotary(Q, sin, cos)  # [B, T_new, D]
+        K_new_rot = self.rotary.apply_rotary(K_new, sin, cos)  # [B, T_new, D]
+
+        # 4) Concatenate the (already rotated) cache with K_new_rot
+        if past_key_values is not None:
+            K = torch.cat([K_past, K_new_rot], dim=1)  # [B, T_all, D]
+            V = torch.cat([V_past, V_new], dim=1)
+        else:
+            K, V = K_new_rot, V_new
+
+        # 5) Attention-free aggregation over V
+        softmax_k = F.softmax(K, dim=1)  # [B, T_all, D]
+        weighted_v = softmax_k * V  # [B, T_all, D]
 
         if self.causal:
-            context = torch.cumsum(weighted_v, dim=1)  # [B, T_all, D]
+            context = torch.cumsum(weighted_v, dim=1)  # [B, T_all, D]
         else:
-            total = weighted_v.sum(dim=1, keepdim=True)  # [B, 1, D]
-            context = total.expand(-1, K.size(1), -1)  # [B, T_all, D]
+            total = weighted_v.sum(dim=1, keepdim=True)  # [B,1,D]
+            context = total.expand(-1, K.size(1), -1)  # [B,T_all,D]
+
+        # 6) Keep only the new positions
+        context_new = context[:, -T_new:, :]  # [B, T_new, D]
 
-        # slice only the new positions
-        context_new = context[:, -T_new:, :]  # [B, T_new, D]
-        gate = self.activation(Q)  # [B, T_new, D]
-        Y = gate * context_new  # [B, T_new, D]
-        Y = self.project(Y)  # [B, T_new, D]
+        # 7) Gating and final projection
+        gate = self.activation(Q_rot)  # [B, T_new, D]
+        Y = gate * context_new  # [B, T_new, D]
+        Y = self.project(Y)  # [B, T_new, D]
 
-        # return output and updated cache
         return Y, (K, V)
 
+
 # -----------------------------------------------------------------------------
 # Flash Attention with ALiBi and KV-cache
 # -----------------------------------------------------------------------------
 
 class FlashAttention(nn.Module):
     def __init__(
-        self,
-        embed_dim: int,
-        num_heads: int,
-        window_size: int,
-        causal: bool = True,
-        qkv_bias: bool = False,
+        self,
+        embed_dim: int,
+        num_heads: int,
+        window_size: int,
+        causal: bool = True,
+        qkv_bias: bool = False,
     ):
         super().__init__()
         assert embed_dim % num_heads == 0, "embed_dim must be divisible by num_heads"
@@ -221,9 +178,9 @@ class FlashAttention(nn.Module):
         self.register_buffer('alibi', get_alibi_slopes(num_heads))
 
     def forward(
-        self,
-        x: torch.Tensor,
-        past_key_values: Optional[Tuple[torch.Tensor, torch.Tensor]] = None
+        self,
+        x: torch.Tensor,
+        past_key_values: Optional[Tuple[torch.Tensor, torch.Tensor]] = None
    ) -> Tuple[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]:
         # x: [B, T_new, D]
         B, T_new, _ = x.size()
@@ -246,11 +203,12 @@ class FlashAttention(nn.Module):
             return_attn_probs=False,
         )
         # attn_out: [B, T_new, H, Dh]
-        out = attn_out.contiguous().view(B, T_new, -1)  # [B, T_new, D]
-        y = self.out_proj(out)  # [B, T_new, D]
+        out = attn_out.contiguous().view(B, T_new, -1)  # [B, T_new, D]
+        y = self.out_proj(out)  # [B, T_new, D]
 
         return y, (k, v)
 
+
 # -----------------------------------------------------------------------------
 # Transformer Blocks and Model
 # -----------------------------------------------------------------------------
@@ -259,7 +217,7 @@ class TransformerBlock(nn.Module):
     def __init__(self, config, att_global: bool = True):
         super().__init__()
         if att_global:
-            self.attn = AFTSimple(embed_dim=config.d_model, causal=config.causal)
+            self.attn = AFTSimple(embed_dim=config.d_model, causal=config.causal, max_position_embeddings=config.max_position_embeddings)
         else:
             self.attn = FlashAttention(
                 embed_dim=config.d_model,
@@ -268,18 +226,18 @@ class TransformerBlock(nn.Module):
                 causal=config.causal,
                 qkv_bias=True,
             )
+
         self.ff = nn.Sequential(
             FeedForward(config.d_model),
-            FeedForward(config.d_model),
         )
         self.ln1 = nn.LayerNorm(config.d_model, eps=config.layer_norm_eps)
         self.ln2 = nn.LayerNorm(config.d_model, eps=config.layer_norm_eps)
         self.drop = nn.Dropout(config.dropout)
 
     def forward(
-        self,
-        x: torch.Tensor,
-        past_key_values: Optional[Tuple[torch.Tensor, torch.Tensor]] = None
+        self,
+        x: torch.Tensor,
+        past_key_values: Optional[Tuple[torch.Tensor, torch.Tensor]] = None
    ) -> Tuple[torch.Tensor, Optional[Tuple[torch.Tensor, torch.Tensor]]]:
         # Attention + residual
         residual = x
@@ -295,6 +253,7 @@ class TransformerBlock(nn.Module):
 
         return x, present
 
+
 class ResidualBlocks(nn.Module):
     def __init__(self, config):
         super().__init__()
@@ -307,10 +266,10 @@ class ResidualBlocks(nn.Module):
         self.final_ln = nn.LayerNorm(config.d_model, eps=config.layer_norm_eps)
 
     def forward(
-        self,
-        x: torch.Tensor,
-        past_key_values: Optional[Tuple[Tuple[torch.Tensor,torch.Tensor], ...]] = None
-    ) -> Tuple[torch.Tensor, Tuple[Tuple[torch.Tensor,torch.Tensor], ...]]:
+        self,
+        x: torch.Tensor,
+        past_key_values: Optional[Tuple[Tuple[torch.Tensor, torch.Tensor], ...]] = None
+    ) -> Tuple[torch.Tensor, Tuple[Tuple[torch.Tensor, torch.Tensor], ...]]:
         new_past = []
         for i, layer in enumerate(self.layers):
             pkv = None if past_key_values is None else past_key_values[i]
@@ -319,6 +278,7 @@ class ResidualBlocks(nn.Module):
         x = self.final_ln(x)
         return x, tuple(new_past)
 
+
 # -----------------------------------------------------------------------------
 # Configuration and Model
 # -----------------------------------------------------------------------------
@@ -327,19 +287,19 @@ class TyneRoxConfig(PretrainedConfig):
     model_type = "tynerox"
 
     def __init__(
-        self,
-        vocab_size: int = 30522,
-        context_length: int = 2048,
-        d_model: int = 1024,
-        num_heads: int = 16,
-        window_size: int = 512,
-        num_hidden_layers: int = 12,
-        causal: bool = True,
-        dropout: float = 0.1,
-        layer_norm_eps: float = 1e-5,
-        tie_word_embeddings: bool = False,
-        pad_token_id:int = 0,
-        **kwargs
+        self,
+        vocab_size: int = 30522,
+        context_length: int = 2048,
+        d_model: int = 1024,
+        num_heads: int = 16,
+        window_size: int = 512,
+        num_hidden_layers: int = 12,
+        causal: bool = True,
+        dropout: float = 0.1,
+        layer_norm_eps: float = 1e-5,
+        tie_word_embeddings: bool = False,
+        pad_token_id: int = 0,
+        **kwargs
     ):
         super().__init__(**kwargs)
         self.vocab_size = vocab_size
@@ -354,27 +314,29 @@ class TyneRoxConfig(PretrainedConfig):
         self.tie_word_embeddings = tie_word_embeddings
         self.pad_token_id = pad_token_id
 
+
 class TyneRoxModel(PreTrainedModel, GenerationMixin):
     config_class = TyneRoxConfig
 
     def __init__(self, config: TyneRoxConfig):
         super().__init__(config)
-        self.embed = PositionalEmbedding(
-            config.vocab_size,
-            config.d_model,
-            config.max_position_embeddings,
-            dropout=config.dropout,
-            encoding_type='rotary'
-        )
+        self.scale = math.sqrt(config.d_model)
+        self.embed = nn.Embedding(config.vocab_size, config.d_model)
         self.transformer = ResidualBlocks(config)
         self.lm_head = nn.Linear(config.d_model, config.vocab_size, bias=False)
+
+        # Correction #2: tie embeddings if requested
+        if config.tie_word_embeddings:
+            self.lm_head.weight = self.embed.weight
+
         self.post_init()
 
+    # Correction #3: get/set input embeddings now operate on the embedding module directly
     def get_input_embeddings(self):
-        return self.embed.token_embedding
+        return self.embed
 
     def set_input_embeddings(self, value):
-        self.embed.token_embedding = value
+        self.embed = value
 
     def get_output_embeddings(self):
         return self.lm_head
@@ -383,24 +345,18 @@ class TyneRoxModel(PreTrainedModel, GenerationMixin):
         self.lm_head = value
 
     def forward(
-        self,
-        input_ids: torch.LongTensor,
-        past_key_values: Optional[Tuple[Tuple[torch.Tensor, torch.Tensor], ...]] = None,
-        labels: Optional[torch.LongTensor] = None,
-        use_cache: bool = True,
-        return_dict: bool = True,
-        **kwargs
+        self,
+        input_ids: torch.LongTensor,
+        past_key_values: Optional[Tuple[Tuple[torch.Tensor, torch.Tensor], ...]] = None,
+        labels: Optional[torch.LongTensor] = None,
+        use_cache: bool = True,
+        return_dict: bool = True,
+        **kwargs
     ) -> Union[Tuple, CausalLMOutputWithPast]:
-        # 1) Embeddings
-        x = self.embed(input_ids)  # [B, T, D]
-
-        # 2) Transformer blocks with KV-cache
+        x = self.embed(input_ids) * self.scale
         x, new_past = self.transformer(x, past_key_values=past_key_values)
+        logits = self.lm_head(x)
 
-        # 3) Project to vocabulary logits
-        logits = self.lm_head(x)  # [B, T, V]
-
-        # 4) Compute loss if labels provided
         loss = None
         if labels is not None:
             shift_logits = logits[:, :-1, :].contiguous()
@@ -411,7 +367,6 @@ class TyneRoxModel(PreTrainedModel, GenerationMixin):
                 ignore_index=-100,
             )
 
-        # 5) Return standardized output
         if not return_dict:
             output = (logits, new_past) if use_cache else (logits,)
             return ((loss,) + output) if loss is not None else output
@@ -429,21 +384,19 @@
    ) -> Tuple[Tuple[torch.Tensor, torch.Tensor], ...]:
         reordered = []
         for k, v in past_key_values:
-            # both have batch dim = dim 0
             reordered.append((k.index_select(0, beam_idx),
                               v.index_select(0, beam_idx)))
         return tuple(reordered)
 
     def prepare_inputs_for_generation(
-        self,
-        input_ids: torch.LongTensor,
-        past_key_values: Optional[Tuple[Tuple[torch.Tensor, torch.Tensor], ...]] = None,
-        **kwargs
+        self,
+        input_ids: torch.LongTensor,
+        past_key_values: Optional[Tuple[Tuple[torch.Tensor, torch.Tensor], ...]] = None,
+        **kwargs
    ) -> dict:
-        # at generation time, only feed in the last token
         if past_key_values is not None:
             input_ids = input_ids[:, -1:].contiguous()
         return {
             "input_ids": input_ids,
             "past_key_values": past_key_values,
-        }
+        }