davda54 committed
Commit 4d4d26d · verified · Parent(s): 28e46d7

Add FlashAttention + unpadding support

Files changed (1):
  modeling_gptbert.py +364 -347

modeling_gptbert.py CHANGED
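In practical terms, after this commit the remote-code model transparently runs FlashAttention 2 over unpadded (varlen) token streams whenever the flash-attn package is installed, and falls back to the dense masked-softmax path otherwise. A minimal usage sketch — the checkpoint id is a placeholder, and flash-attn requires a CUDA device with fp16/bf16 activations:

    import torch
    from transformers import AutoModelForMaskedLM, AutoTokenizer

    repo = "ltg/gpt-bert-babylm-base"  # placeholder id; any checkpoint shipping this file works the same way
    tokenizer = AutoTokenizer.from_pretrained(repo, trust_remote_code=True)
    model = AutoModelForMaskedLM.from_pretrained(
        repo,
        trust_remote_code=True,      # pulls modeling_gptbert.py from the repo
        torch_dtype=torch.bfloat16,  # FlashAttention only supports fp16/bf16
    ).cuda().eval()

    batch = tokenizer(
        ["A short text.", "A somewhat longer piece of text."],
        padding=True, return_tensors="pt",
    ).to("cuda")
    with torch.no_grad():
        logits = model(input_ids=batch["input_ids"], attention_mask=batch["attention_mask"]).logits
    print(logits.shape)  # [batch, seq_len, vocab_size] -- re-padded after varlen attention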
@@ -5,11 +5,12 @@ import torch.nn as nn
5
  from torch.nn import functional as F
6
  from torch import _softmax_backward_data as _softmax_backward_data
7
 
8
- from functools import partial
9
 
10
  from .configuration_gptbert import GptBertConfig
11
  from transformers.modeling_utils import PreTrainedModel
12
  from transformers.activations import gelu_new
 
13
  from transformers.modeling_outputs import (
14
  MaskedLMOutput,
15
  MultipleChoiceModelOutput,
@@ -22,82 +23,71 @@ from transformers.modeling_outputs import (
22
  import math
23
  from typing import TYPE_CHECKING, Optional, Union, Tuple, List
24
 
25
- try:
26
- from torch.nn.attention.flex_attention import flex_attention, create_block_mask
27
- except ImportError:
28
- pass
29
 
30
 
31
- class ModelOutput:
32
 
33
- def __init__(
34
- self,
35
- logits: torch.Tensor | None = None,
36
- loss: torch.Tensor | float | None = None,
37
- perplexity: torch.Tensor | float | None = None,
38
- accuracy: float | None = None,
39
- z_loss: torch.Tensor | float | None = None,
40
- **kwargs
41
- ):
42
- self.logits: torch.Tensor | None
43
- self.loss: torch.Tensor | float | None
44
- self.perplexity: torch.Tensor | float | None
45
- self.accuracy: float | None
46
- self.z_loss: torch.Tensor | float | None
47
 
48
- self.logits = logits
49
- self.loss = loss
50
- self.perplexity = perplexity
51
- self.accuracy = accuracy
52
- self.z_loss = z_loss
53
 
54
- for attr, value in kwargs.items():
55
- setattr(self, attr, value)
56
 
57
 
58
- class CastedLinear(nn.Linear):
59
 
 
60
  def __init__(self, in_features, out_features, bias):
61
  super().__init__(in_features, out_features, bias=bias)
62
 
63
- def reset_parameters(self) -> None:
64
- std: float = math.sqrt(2.0 / (self.in_features + self.out_features))
65
- nn.init.trunc_normal_(self.weight, mean=0.0, std=std, a=-2*std, b=2*std)
66
-
67
  def forward(self, x):
68
  return F.linear(x, self.weight.type_as(x), bias=self.bias.type_as(x) if self.bias is not None else None)
69
 
70
 
71
  class CastedLinearIn(nn.Linear):
72
-
73
  def __init__(self, in_features, out_features, bias):
74
  super().__init__(in_features, out_features, bias=bias)
75
  self.scale = nn.Parameter(torch.ones(in_features))
76
 
77
- def reset_parameters(self) -> None:
78
- std: float = math.sqrt(2.0 / (self.in_features + self.out_features))
79
- nn.init.trunc_normal_(self.weight, mean=0.0, std=std, a=-2*std, b=2*std)
80
-
81
  def forward(self, x):
82
  return F.linear(x, (self.weight * (self.scale + 1.0).unsqueeze(0)).type_as(x), bias=self.bias.type_as(x) if self.bias is not None else None)
83
 
84
 
85
  class CastedLinearOut(nn.Linear):
86
-
87
  def __init__(self, in_features, out_features, bias):
88
  super().__init__(in_features, out_features, bias=bias)
89
  self.scale = nn.Parameter(torch.ones(out_features))
90
 
91
- def reset_parameters(self) -> None:
92
- std: float = math.sqrt(2.0 / (self.in_features + self.out_features))
93
- nn.init.trunc_normal_(self.weight, mean=0.0, std=std, a=-2*std, b=2*std)
94
-
95
  def forward(self, x):
96
  return F.linear(x, (self.scale.unsqueeze(1) * self.weight).type_as(x), bias=self.bias.type_as(x) if self.bias is not None else None)
97
 
98
 
99
  class MultiCastedLinearOrtho(nn.Module):
100
-
101
  def __init__(self, in_features, out_features, bias):
102
  super().__init__()
103
  self.in_features = in_features
@@ -112,19 +102,11 @@ class MultiCastedLinearOrtho(nn.Module):
112
  else:
113
  self.bias = self.register_parameter("bias", None)
114
 
115
- self.reset_parameters()
116
-
117
- def reset_parameters(self) -> None:
118
- for i, weight in enumerate(self.weights):
119
- std: float = math.sqrt(2.0 / (self.in_features + self.out_features[i]))
120
- nn.init.trunc_normal_(weight, mean=0.0, std=std, a=-2*std, b=2*std)
121
-
122
  def forward(self, x):
123
  return F.linear(x, torch.cat([weight for weight in self.weights], dim=0).type_as(x), bias=self.bias.type_as(x) if self.bias is not None else None)
124
 
125
 
126
  class MultiCastedLinearOrthoIn(nn.Module):
127
-
128
  def __init__(self, in_features, out_features, bias):
129
  super().__init__()
130
  self.in_features = in_features
@@ -141,23 +123,14 @@ class MultiCastedLinearOrthoIn(nn.Module):
141
 
142
  self.scale = nn.Parameter(torch.ones(in_features))
143
 
144
- self.reset_parameters()
145
-
146
- def reset_parameters(self) -> None:
147
- for weight in self.weights:
148
- std = 0.5 * (self.in_features ** -0.5)
149
- bound = (3 ** 0.5) * std
150
- with torch.no_grad():
151
- weight.uniform_(-bound, bound)
152
-
153
  def forward(self, x):
154
  return F.linear(x, (torch.cat([weight for weight in self.weights], dim=0) * (self.scale + 1.0).unsqueeze(0)).type_as(x), bias=self.bias.type_as(x) if self.bias is not None else None)
155
 
156
 
157
  class MultiCastedLinearOrthoOut(nn.Module):
158
-
159
  def __init__(self, in_features, out_features, bias):
160
  super().__init__()
 
161
  self.in_features = in_features
162
  self.out_features = out_features
163
 
@@ -172,15 +145,6 @@ class MultiCastedLinearOrthoOut(nn.Module):
172
 
173
  self.scale = nn.Parameter(torch.ones(sum(out_features)))
174
 
175
- self.reset_parameters()
176
-
177
- def reset_parameters(self) -> None:
178
- for weight in self.weights:
179
- std = 0.5 * (self.in_features ** -0.5)
180
- bound = (3 ** 0.5) * std
181
- with torch.no_grad():
182
- weight.uniform_(-bound, bound)
183
-
184
  def forward(self, x):
185
  return F.linear(x, (self.scale.unsqueeze(1) * torch.cat([weight for weight in self.weights], dim=0)).type_as(x), bias=self.bias.type_as(x) if self.bias is not None else None)
186
 
@@ -188,15 +152,12 @@ class MultiCastedLinearOrthoOut(nn.Module):
188
  class GeGLU(nn.Module):
189
  def forward(self, x):
190
  x, gate = x.chunk(2, dim=-1)
191
- x = x * gelu_new(gate)
192
- return x
193
 
194
 
195
  class MaskedSoftmax(torch.autograd.Function):
196
  @staticmethod
197
- def forward(ctx, x: torch.Tensor, mask: torch.BoolTensor, dim: int) -> torch.Tensor:
198
- ctx.dim: int
199
-
200
  ctx.dim = dim
201
  x.masked_fill_(mask, float('-inf'))
202
  x = torch.softmax(x, ctx.dim)
@@ -205,47 +166,34 @@ class MaskedSoftmax(torch.autograd.Function):
205
  return x
206
 
207
  @staticmethod
208
- def backward(ctx, grad_output: torch.Tensor) -> tuple[torch.Tensor, None, None]:
209
- output: torch.Tensor
210
-
211
  output, = ctx.saved_tensors
212
  inputGrad: torch.Tensor = _softmax_backward_data(grad_output, output, ctx.dim, output.dtype)
213
  return inputGrad, None, None
214
 
215
 
216
  class Encoder(nn.Module):
217
-
218
- def __init__(self, config) -> None:
219
  super().__init__()
220
 
221
- self.layers: nn.ModuleList[Layer]
222
-
223
  self.layers = nn.ModuleList([Layer(config, i) for i in range(config.num_layers)])
224
-
225
- for i, layer in enumerate(self.layers):
226
- for weight in layer.mlp.up_proj.weights:
227
- weight.data *= math.sqrt(1.0 / (2.0 * (i + 1)))
228
- layer.mlp.down_proj.weight.data *= math.sqrt(1.0 / (2.0 * (i + 1)))
229
-
230
  self.short_long_ratio = config.short_long_ratio
231
 
232
- def set_window_length(self, config) -> None:
233
  for i, layer in enumerate(self.layers):
234
  if (i+1) % self.short_long_ratio == 0:
235
- layer.set_window_length(config.window_length, config.not_flex)
236
  else:
237
- layer.set_window_length(256, config.not_flex)
238
-
239
- def forward(self, hidden_layer: torch.Tensor, embeddings: torch.Tensor, mask: torch.Tensor | None = None) -> torch.Tensor:
240
- hidden_layer: List[torch.Tensor]
241
- attention_probs: List[torch.Tensor]
242
 
 
243
  hidden_states = []
244
  attention_probs = []
245
  v1 = None
 
246
 
247
  for layer in self.layers:
248
- hidden_layer, v1, attention_p = layer(hidden_layer, embeddings, v1, mask)
249
  hidden_states.append(hidden_layer)
250
  attention_probs.append(attention_p)
251
 
@@ -253,29 +201,22 @@ class Encoder(nn.Module):
253
 
254
 
255
  class Layer(nn.Module):
256
-
257
- def __init__(self, config, layer_idx: int) -> None:
258
  super().__init__()
259
 
260
- self.attention: SelfAttention
261
- self.mlp: FeedForward
262
-
263
  self.attention = SelfAttention(config, layer_idx)
264
  self.mlp = FeedForward(config)
265
  self.lambdas = nn.Parameter(torch.tensor([0., 0., 1., 0., 1., 0.]))
266
 
267
- def set_window_length(self, window_length: int, not_flex: bool) -> None:
268
- self.attention.set_window_length(window_length, not_flex)
269
-
270
- def forward(self, hidden_layer: torch.Tensor, embeddings: torch.Tensor, v1: torch.Tensor | None, mask: torch.Tensor | None = None) -> Tuple[torch.Tensor, torch.Tensor]:
271
- output: torch.Tensor
272
- attention_p: torch.Tensor
273
 
 
274
  attention_output = (1 - self.lambdas[0]) * hidden_layer + self.lambdas[0] * embeddings
275
  qk_layer = (1 - self.lambdas[1]) * hidden_layer + self.lambdas[1] * embeddings
276
  mlp_layer = F.softplus(self.lambdas[2]) * ((1 - self.lambdas[3]) * hidden_layer + self.lambdas[3] * embeddings)
277
 
278
- attention_output, v1, attention_p = self.attention(attention_output, qk_layer, v1, mask)
279
  mlp_layer = mlp_layer + attention_output
280
  hidden_layer = F.softplus(self.lambdas[4]) * ((1 - self.lambdas[5]) * hidden_layer + self.lambdas[5] * embeddings)
281
  output = hidden_layer + attention_output + self.mlp(mlp_layer)
@@ -284,97 +225,92 @@ class Layer(nn.Module):
284
 
285
 
286
  class Embedding(nn.Module):
287
-
288
- def __init__(self, config) -> None:
289
  super().__init__()
290
 
291
  assert hasattr(config, "vocab_size"), "The config must have a vocab_size attribute!"
292
  assert hasattr(config, "hidden_size"), "The config must have a hidden_size attribute!"
293
  assert hasattr(config, "embedding_dropout_p"), "The model must have a embedding_dropout_p attribute!"
294
 
295
- self.word_embedding: nn.Embedding
296
- self.word_norm: nn.LayerNorm
297
- self.dropout: nn.Dropout
298
-
299
  self.word_embedding = nn.Embedding(config.vocab_size, config.hidden_size)
300
  self.word_norm = nn.LayerNorm(config.hidden_size, eps=config.word_norm_eps, elementwise_affine=False, bias=False)
301
  self.word_scale = nn.Parameter(torch.zeros(config.hidden_size))
302
 
303
  self.dropout = nn.Dropout(config.embedding_dropout_p)
304
 
305
- self.initialize(config.hidden_size, config.vocab_size)
306
-
307
- @torch.no_grad()
308
- def initialize(self, hidden_size: int, vocab_size: int) -> None:
309
- std: float
310
-
311
- std = math.sqrt(2.0 / (hidden_size + vocab_size))
312
- nn.init.trunc_normal_(self.word_embedding.weight, mean=0.0, std=std, a=-2*std, b=2*std)
313
-
314
- def forward(self, input_ids: torch.Tensor) -> torch.Tensor:
315
- word_embedding: torch.Tensor
316
-
317
  word_embedding = self.word_embedding(input_ids)
318
  word_embedding = self.word_norm(word_embedding)
319
- word_embedding = (word_embedding * (self.word_scale + 1.0).unsqueeze(0).unsqueeze(0))
320
 
321
  return self.dropout(word_embedding)
322
 
323
 
324
  class MaskClassifier(nn.Module):
325
-
326
- def __init__(self, config, embedding_weights: nn.Parameter) -> None:
327
  super().__init__()
328
 
329
- self.projection: CastedLinear
330
- self.emb2vocab: CastedLinear
331
- self.pre_norm: nn.LayerNorm
332
- self.post_norm: nn.LayerNorm
333
-
334
  self.pre_norm = nn.LayerNorm(config.hidden_size, eps=config.classifier_pre_norm_eps, elementwise_affine=config.classifier_pre_norm_affine)
335
  self.projection = CastedLinearIn(config.hidden_size, config.hidden_size, bias=False)
336
  self.post_norm = nn.LayerNorm(config.hidden_size, eps=config.classifier_post_norm_eps, elementwise_affine=config.classifier_post_norm_affine)
337
  self.emb2vocab = CastedLinearIn(config.hidden_size, config.vocab_size, bias=True)
338
 
339
- self.initialize(config.hidden_size, config.vocab_size, embedding_weights)
340
-
341
- @torch.no_grad()
342
- def initialize(self, hidden_size: int, vocab_size: int, embedding_weights: nn.Parameter) -> None:
343
- proj_std: float = math.sqrt(2.0 / (hidden_size + 4*hidden_size))
344
-
345
- nn.init.trunc_normal_(self.projection.weight, mean=0.0, std=proj_std, a=-2*proj_std, b=2*proj_std)
346
- self.emb2vocab.weight = embedding_weights
347
- self.emb2vocab.bias.zero_()
348
-
349
- def project(self, hidden_layer: torch.Tensor) -> torch.Tensor:
350
- projection: torch.Tensor
351
-
352
- projection = self.projection(hidden_layer)
353
- projection = gelu_new(projection)
354
- projection = self.post_norm(projection)
355
-
356
- return projection
357
-
358
- def calculate_output(self, hidden_layer: torch.Tensor) -> torch.Tensor:
359
- return self.emb2vocab(hidden_layer)
360
-
361
- def forward(self, hidden_layer: torch.Tensor, labels: torch.Tensor | None = None) -> torch.Tensor:
362
- output: torch.Tensor
363
-
364
- if labels is not None:
365
- hidden_layer = torch.index_select(hidden_layer.flatten(0, 1), 0, torch.nonzero(labels.flatten() != -100).squeeze())
366
-
367
- hidden_layer = self.pre_norm(hidden_layer)
368
- hidden_layer = self.project(hidden_layer)
369
- output = self.calculate_output(hidden_layer)
370
-
371
- return output
372
 
373
 
374
  class SelfAttention(nn.Module):
375
-
376
- def __init__(self, config, layer_idx) -> None:
377
  super().__init__()
 
 
 
 
378
  self.d_qk = config.d_qk
379
  self.d_v = config.d_v
380
  self.num_attention_heads = config.num_attention_heads
@@ -398,72 +334,59 @@ class SelfAttention(nn.Module):
398
  self.q_scale = nn.Parameter(torch.ones(self.num_attention_heads, config.d_qk))
399
 
400
  self.dropout = nn.Dropout(config.attention_output_dropout_p)
 
 
401
 
402
  theta = 160_000 if (layer_idx + 1) % config.short_long_ratio == 0 else 10_000
403
 
404
- self.rope_embedding = RotaryPositionalEmbeddings(config, theta)
405
- self.scale: float = 1.0 / math.sqrt(self.d_qk)
406
 
 
407
  self.dropout = nn.Dropout(config.attention_dropout if hasattr(config, "attention_dropout") else 0.0)
408
 
409
  self.lambdas = nn.Parameter(torch.tensor([0.5]))
410
 
411
- self.initialize()
412
-
413
  self.sequence_length = config.max_sequence_length
414
  self.is_causal = config.is_decoder
415
- self.not_flex = config.not_flex
416
-
417
- @torch.no_grad()
418
- def initialize(self) -> None:
419
- std: float = math.sqrt(2.0 / (self.hidden_size + 4*self.hidden_size))
420
- for weight in self.qk_proj.weights:
421
- nn.init.trunc_normal_(weight, mean=0.0, std=std, a=-2*std, b=2*std)
422
- nn.init.trunc_normal_(self.v_proj.weight, mean=0.0, std=std, a=2*std, b=2*std)
423
- self.out_proj.weight.data.zero_()
424
 
425
- def set_window_length(self, window_length: int, not_flex: bool) -> None:
426
- self.window_length: int = window_length
427
- if not not_flex:
428
- self.block_mask = self.create_block_mask(window_length)
429
 
430
- def causal_mask_mode(self, window_length, b, _, q_idx, kv_idx):
431
- return (q_idx >= kv_idx) & ((q_idx - kv_idx) < window_length)
432
-
433
- def bidirectional_mask_mode(self, window_length, b, _, q_idx, kv_idx):
434
- return ((q_idx - kv_idx) < window_length) & ((kv_idx - q_idx) < window_length)
435
-
436
- def create_block_mask(self, window_length: int) -> torch.Tensor:
437
  if self.is_causal:
438
- return create_block_mask(
439
- partial(self.causal_mask_mode, self.window_length),
440
- 1, 1, self.sequence_length, self.sequence_length, device=self.k_scale.device
441
- )
442
  else:
443
- return create_block_mask(
444
- partial(self.bidirectional_mask_mode, self.window_length),
445
- 1, 1, self.sequence_length, self.sequence_length, device=self.k_scale.device
446
- )
447
-
448
- def attention_operation(self, query: torch.Tensor, key: torch.Tensor, value: torch.Tensor, padding_mask: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
449
- attention_scores: torch.Tensor
450
- attention_probabilities: torch.Tensor
451
- batch_size: int
452
- query_length: int
453
- key_length: int
454
 
 
 
455
  batch_size, _, query_length, _ = query.size()
456
  _, _, key_length, _ = key.size()
457
 
458
- if self.is_causal:
459
- window_mask = ~torch.ones(query_length, key_length, dtype=torch.bool, device=self.k_scale.device).tril().triu(diagonal=-self.window_length).view(1, 1, query_length, key_length)
460
- else:
461
- window_mask = ~torch.ones(query_length, key_length, dtype=torch.bool, device=self.k_scale.device).tril(diagonal=self.window_length).triu(diagonal=-self.window_length).view(1, 1, query_length, key_length)
462
-
463
- if padding_mask is not None:
464
- attention_mask = padding_mask | window_mask
465
- else:
466
- attention_mask = window_mask
467
 
468
  attention_scores = torch.bmm(query.flatten(0, 1), key.transpose(-1, -2).flatten(0, 1)) * self.scale # shape: [B*H, T, T]
469
  attention_scores = attention_scores.view(batch_size, self.num_attention_heads, query_length, key_length)
@@ -476,120 +399,226 @@ class SelfAttention(nn.Module):
476
 
477
  return value, attention_probabilities.detach()
478
 
479
- def forward(self, hidden_layer: torch.Tensor, qk_layer: torch.Tensor, v1: torch.Tensor | None, mask: torch.Tensor | None = None, doc_ids: torch.Tensor | None = None) -> Tuple[torch.Tensor, torch.Tensor]:
480
  hidden_layer = self.pre_v_norm(hidden_layer)
481
  qk_layer = self.pre_qk_norm(qk_layer)
482
 
483
  query, key = self.qk_proj(qk_layer).tensor_split([self.q_out_dim], dim=-1)
484
  value = self.v_proj(hidden_layer)
485
 
486
- query_length: int = hidden_layer.size(0)
487
- key_length: int = hidden_layer.size(0)
488
- batch_size: int = hidden_layer.size(1)
 
 
489
 
490
- query = query.reshape(query_length, batch_size, self.num_attention_heads, self.d_qk).permute(1, 2, 0, 3) # shape: [B, H, T, D]
491
- key = key.reshape(key_length, batch_size, self.num_kv_heads, self.d_qk).permute(1, 2, 0, 3) # shape: [B, H, T, D]
492
- value = value.reshape(key_length, batch_size, self.num_kv_heads, self.d_qk).permute(1, 2, 0, 3) # shape: [B, H, T, D]
493
 
494
- query, key = ((self.q_scale + 1.0).unsqueeze(1).unsqueeze(0) * self.q_norm(query.float())).type_as(query), ((self.k_scale + 1.0).unsqueeze(1).unsqueeze(0) * self.k_norm(key.float())).type_as(key)
 
 
495
 
496
- if v1 is None:
497
- v1 = value
498
- value = (1 - self.lambdas[0]) * value + self.lambdas[0] * v1
499
 
500
- query = self.rope_embedding(query)
501
- key = self.rope_embedding(key)
 
502
 
503
- if self.not_flex:
504
- output, attention_probabilities = self.attention_operation(query, key, value, mask)
505
  else:
506
- def document_score_mod(score, b, _, q_idx, kv_idx):
507
- return torch.where(doc_ids[q_idx] == doc_ids[kv_idx], score, -float("inf"))
508
-
509
- if self.is_causal:
510
- block_mask = create_block_mask(
511
- partial(self.causal_mask_mode, self.window_length),
512
- 1, 1, query_length, key_length, device=self.k_scale.device
513
- )
514
- else:
515
- block_mask = create_block_mask(
516
- partial(self.bidirectional_mask_mode, self.window_length),
517
- 1, 1, query_length, key_length, device=self.k_scale.device
518
- )
519
 
520
- output = flex_attention(
521
- query, key, value, block_mask=block_mask, enable_gqa=True
522
- )
523
- attention_probabilities = None
524
 
525
- output = output.permute(2, 0, 1, 3).flatten(2, 3) # shape: [T, B, H*D]
526
- output = self.inter_norm(output)
527
- output = self.out_proj(output)
528
 
529
- return self.dropout(output), v1, attention_probabilities
 
 
530
 
 
 
 
531
 
532
- class FeedForward(nn.Module):
 
 
 
 
533
 
534
- def __init__(self, config) -> None:
535
- super().__init__()
536
 
537
- self.up_proj: CastedLinear
538
- self.down_proj: CastedLinear
539
- self.pre_norm: nn.LayerNorm
540
- self.inter_norm: nn.LayerNorm
541
- self.activation: GeGLU
542
- self.dropout: nn.Dropout
543
 
544
  self.pre_norm = nn.LayerNorm(config.hidden_size, eps=config.feed_forward_pre_norm_eps, elementwise_affine=config.feed_forward_pre_norm_affine)
545
  self.up_proj = MultiCastedLinearOrthoIn(config.hidden_size, [config.intermediate_size, config.intermediate_size], bias=False)
546
  self.activation = GeGLU()
547
  self.inter_norm = nn.LayerNorm(config.intermediate_size, eps=config.feed_forward_inter_norm_eps, elementwise_affine=config.feed_forward_inter_norm_affine)
548
  self.down_proj = CastedLinearIn(config.intermediate_size, config.hidden_size, bias=False)
549
  self.dropout = nn.Dropout(config.feed_forward_dropout_p)
550
 
551
- self.initialize(config.hidden_size)
552
 
553
- @torch.no_grad()
554
- def initialize(self, hidden_size: int) -> None:
555
- std: float = math.sqrt(2.0 / (5*hidden_size))
556
 
557
- for weight in self.up_proj.weights:
558
- nn.init.trunc_normal_(weight, mean=0.0, std=std, a=-2*std, b=2*std)
559
- self.down_proj.weight.data.zero_()
560
 
561
- def up_project(self, hidden_layer: torch.Tensor) -> torch.Tensor:
562
- hidden_layer = self.pre_norm(hidden_layer)
563
- return self.up_proj(hidden_layer)
564
 
565
- def activate(self, projection: torch.Tensor) -> torch.Tensor:
566
- activated_projection: torch.Tensor
567
 
568
- activated_projection = self.activation(projection)
569
- activated_projection = self.inter_norm(activated_projection.float()).type_as(projection)
570
 
571
- return activated_projection
572
 
573
- def down_project(self, activated_projection: torch.Tensor) -> torch.Tensor:
574
- output: torch.Tensor
575
 
576
- output = self.down_proj(activated_projection)
 
 
 
577
 
578
- return self.dropout(output)
 
579
 
580
- def forward(self, hidden_layer: torch.Tensor) -> torch.Tensor:
581
- output: torch.Tensor
 
582
 
583
- output = self.up_project(hidden_layer)
584
- output = self.activate(output)
585
- output = self.down_project(output)
 
 
 
 
586
 
587
- return output
588
 
589
 
590
  class RotaryPositionalEmbeddings(nn.Module):
591
-
592
- def __init__(self, config, theta: int) -> None:
593
  super().__init__()
594
 
595
  assert hasattr(config, "d_qk"), "The config must have a d_qk attribute!"
@@ -615,7 +644,7 @@ class RotaryPositionalEmbeddings(nn.Module):
615
  self.register_buffer("cos_matrix", embedding.cos(), persistent=False)
616
  self.register_buffer("sin_matrix", embedding.sin(), persistent=False)
617
 
618
- def forward(self, x: torch.Tensor) -> torch.Tensor:
619
  seq_len: int
620
  cos_matrix: torch.Tensor
621
  sin_matrix: torch.Tensor
@@ -647,18 +676,17 @@ class RotaryPositionalEmbeddings(nn.Module):
647
 
648
  class GptBertPreTrainedModel(PreTrainedModel):
649
  config_class = GptBertConfig
650
- supports_gradient_checkpointing = False
651
-
652
- def _set_gradient_checkpointing(self, module, value=False):
653
- raise NotImplementedError("Gradient checkpointing is not supported by this model")
654
 
655
  def _init_weights(self, module):
656
  pass
657
 
658
 
659
  class GptBertModel(GptBertPreTrainedModel):
660
-
661
- def __init__(self, config, add_mlm_layer=False, **kwargs):
662
  super().__init__(config, **kwargs)
663
  self.config = config
664
  self.hidden_size = config.hidden_size
@@ -680,7 +708,8 @@ class GptBertModel(GptBertPreTrainedModel):
680
  def get_contextualized_embeddings(
681
  self,
682
  input_ids: Optional[torch.Tensor] = None,
683
- attention_mask: Optional[torch.Tensor] = None
 
684
  ) -> List[torch.Tensor]:
685
  if input_ids is not None:
686
  input_shape = input_ids.size()
@@ -690,35 +719,36 @@ class GptBertModel(GptBertPreTrainedModel):
690
  batch_size, seq_length = input_shape
691
  device = input_ids.device
692
 
693
- # if attention_mask is None:
694
- # attention_mask = torch.zeros(batch_size, seq_length, dtype=torch.bool, device=device)
695
- if attention_mask is not None:
696
- attention_mask = ~attention_mask.bool()
697
-
698
- if len(attention_mask.size()) == 2:
699
- attention_mask = attention_mask.unsqueeze(1).unsqueeze(2)
700
- elif len(attention_mask.size()) == 3:
701
- attention_mask = attention_mask.unsqueeze(1)
702
 
703
- if self.config.is_decoder:
704
- attention_mask = attention_mask | torch.triu(torch.ones(seq_length, seq_length, dtype=torch.bool, device=device), 1).unsqueeze(0).unsqueeze(0)
 
 
 
705
 
706
- static_embeddings = self.embedding(input_ids.t())
707
- contextualized_embeddings, attention_probs = self.encoder(static_embeddings, static_embeddings, attention_mask)
708
- contextualized_embeddings = [e.transpose(0, 1) for e in contextualized_embeddings]
709
  last_layer = contextualized_embeddings[-1]
710
- contextualized_embeddings = [contextualized_embeddings[0]] + [
711
- contextualized_embeddings[i] - contextualized_embeddings[i - 1]
712
- for i in range(1, len(contextualized_embeddings))
713
- ]
 
 
 
 
 
714
  return last_layer, contextualized_embeddings, attention_probs
715
 
716
  def forward(
717
  self,
718
  input_ids: Optional[torch.Tensor] = None,
719
  attention_mask: Optional[torch.Tensor] = None,
720
- token_type_ids: Optional[torch.Tensor] = None,
721
- position_ids: Optional[torch.Tensor] = None,
722
  output_hidden_states: Optional[bool] = None,
723
  output_attentions: Optional[bool] = None,
724
  return_dict: Optional[bool] = None,
@@ -726,7 +756,9 @@ class GptBertModel(GptBertPreTrainedModel):
726
  ) -> Union[Tuple[torch.Tensor], BaseModelOutput]:
727
  return_dict = return_dict if return_dict is not None else self.config.use_return_dict
728
 
729
- sequence_output, contextualized_embeddings, attention_probs = self.get_contextualized_embeddings(input_ids, attention_mask)
 
 
730
 
731
  if not return_dict:
732
  return (
@@ -745,7 +777,7 @@ class GptBertModel(GptBertPreTrainedModel):
745
  class GptBertForMaskedLM(GptBertModel):
746
  _keys_to_ignore_on_load_unexpected = ["head"]
747
 
748
- def __init__(self, config, **kwargs):
749
  super().__init__(config, add_mlm_layer=True, **kwargs)
750
 
751
  def get_output_embeddings(self):
@@ -799,7 +831,7 @@ class GptBertForMaskedLM(GptBertModel):
799
 
800
 
801
  class Classifier(nn.Module):
802
- def __init__(self, config, num_labels: int):
803
  super().__init__()
804
 
805
  drop_out = getattr(config, "cls_dropout", None)
@@ -826,34 +858,19 @@ class Classifier(nn.Module):
826
  nn.init.trunc_normal_(self.emb2vocab.weight, mean=0.0, std=proj_std, a=-2*proj_std, b=2*proj_std)
827
  self.emb2vocab.bias.zero_()
828
 
829
- def project(self, hidden_layer: torch.Tensor) -> torch.Tensor:
830
- projection: torch.Tensor
831
-
832
- projection = self.pre_norm(hidden_layer)
833
- projection = self.dropout(projection)
834
- projection = self.projection(projection)
835
- projection = gelu_new(projection)
836
- projection = self.post_norm(projection)
837
-
838
- return projection
839
-
840
- def calculate_output(self, hidden_layer: torch.Tensor) -> torch.Tensor:
841
- return self.emb2vocab(hidden_layer)
842
-
843
- def forward(self, hidden_layer: torch.Tensor) -> torch.Tensor:
844
- output: torch.Tensor
845
- projection: torch.Tensor
846
-
847
- projection = self.project(hidden_layer)
848
- output = self.calculate_output(projection)
849
-
850
- return output
851
 
852
 
853
  class GptBertForCausalLM(GptBertModel):
854
  _keys_to_ignore_on_load_unexpected = ["head"]
855
 
856
- def __init__(self, config, **kwargs):
857
  config.is_decoder = True
858
  super().__init__(config, add_mlm_layer=True, **kwargs)
859
 
@@ -978,7 +995,7 @@ class GptBertForCausalLM(GptBertModel):
978
  class GptBertForSequenceClassification(GptBertModel):
979
  _keys_to_ignore_on_load_unexpected = ["classifier"]
980
 
981
- def __init__(self, config, **kwargs):
982
  super().__init__(config, add_mlm_layer=False, **kwargs)
983
 
984
  self.num_labels = config.num_labels
@@ -1043,7 +1060,7 @@ class GptBertForSequenceClassification(GptBertModel):
1043
  class GptBertForTokenClassification(GptBertModel):
1044
  _keys_to_ignore_on_load_unexpected = ["classifier"]
1045
 
1046
- def __init__(self, config, **kwargs):
1047
  super().__init__(config, add_mlm_layer=False, **kwargs)
1048
 
1049
  self.num_labels = config.num_labels
@@ -1090,7 +1107,7 @@ class GptBertForTokenClassification(GptBertModel):
1090
  class GptBertForQuestionAnswering(GptBertModel):
1091
  _keys_to_ignore_on_load_unexpected = ["classifier"]
1092
 
1093
- def __init__(self, config, **kwargs):
1094
  super().__init__(config, add_mlm_layer=False, **kwargs)
1095
 
1096
  self.num_labels = config.num_labels
@@ -1157,7 +1174,7 @@ class GptBertForQuestionAnswering(GptBertModel):
1157
  class GptBertForMultipleChoice(GptBertModel):
1158
  _keys_to_ignore_on_load_unexpected = ["classifier"]
1159
 
1160
- def __init__(self, config, **kwargs):
1161
  super().__init__(config, add_mlm_layer=False, **kwargs)
1162
 
1163
  self.num_labels = getattr(config, "num_labels", 2)
 
5
  from torch.nn import functional as F
6
  from torch import _softmax_backward_data as _softmax_backward_data
7
 
8
+ from functools import partial, lru_cache
9
 
10
  from .configuration_gptbert import GptBertConfig
11
  from transformers.modeling_utils import PreTrainedModel
12
  from transformers.activations import gelu_new
13
+ from transformers.utils import is_flash_attn_2_available
14
  from transformers.modeling_outputs import (
15
  MaskedLMOutput,
16
  MultipleChoiceModelOutput,
 
23
  import math
24
  from typing import TYPE_CHECKING, Optional, Union, Tuple, List
25
 
26
+ if is_flash_attn_2_available():
27
+ from flash_attn.flash_attn_interface import flash_attn_varlen_qkvpacked_func
28
+ from flash_attn.layers.rotary import RotaryEmbedding
29
+ from flash_attn.ops.triton.rotary import apply_rotary
+ else:
+ RotaryEmbedding = object  # fallback base so UnpaddedRotaryEmbedding below can be defined without flash-attn
30
 
31
 
32
+ # from https://github.com/huggingface/transformers/blob/main/src/transformers/models/modernbert/modeling_modernbert.py
33
+ def _unpad_input(input_ids: torch.Tensor, attention_mask: torch.Tensor):
34
+ seqlens_in_batch = attention_mask.sum(dim=-1, dtype=torch.int32)
35
+ indices = torch.nonzero(attention_mask.flatten(), as_tuple=False).flatten()
36
+ max_seqlen_in_batch = int(seqlens_in_batch.max().item())
37
+ cu_seqlens = F.pad(torch.cumsum(seqlens_in_batch, dim=0, dtype=torch.int32), (1, 0))
38
 
39
+ if input_ids.dim() == 2:
40
+ unpadded_inputs = input_ids.flatten()[indices]
41
+ else:
42
+ batch_size, sequence_length, *rest = input_ids.shape
43
+ shape = batch_size * sequence_length
44
+ unpadded_inputs = input_ids.view(shape, *rest)[indices]
45
 
46
+ return unpadded_inputs, indices, cu_seqlens, max_seqlen_in_batch
 
 
 
 
47
 
 
 
48
 
49
+ # from https://github.com/huggingface/transformers/blob/main/src/transformers/models/modernbert/modeling_modernbert.py
50
+ def _pad_output(input_ids: torch.Tensor, indices: torch.Tensor, batch_size: int, sequence_length: int) -> torch.Tensor:
51
+ if input_ids.dim() == 1:
52
+ output = torch.zeros(batch_size * sequence_length, dtype=input_ids.dtype, device=input_ids.device)
53
+ output[indices] = input_ids
54
+ padded_inputs = output.view(batch_size, sequence_length)
55
+ else:
56
+ _, *rest = input_ids.shape
57
+ output = torch.zeros(batch_size * sequence_length, *rest, dtype=input_ids.dtype, device=input_ids.device)
58
+ output[indices] = input_ids
59
+ padded_inputs = output.view(batch_size, sequence_length, *rest)
60
+
61
+ return padded_inputs
62
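To make the unpadding contract concrete, here is a small round-trip sketch using the two helpers above (arbitrary values; the shapes are what matter): two sequences of true lengths 3 and 1, padded to length 4, collapse into four unpadded rows, and _pad_output restores the padded layout with zeros in the gaps.

    import torch

    attention_mask = torch.tensor([[1, 1, 1, 0],
                                   [1, 0, 0, 0]], dtype=torch.bool)
    hidden = torch.randn(2, 4, 8)  # [batch, seq_len, hidden]

    flat, indices, cu_seqlens, max_seqlen = _unpad_input(hidden, attention_mask)
    print(flat.shape)   # torch.Size([4, 8]) -- only the four real tokens survive
    print(cu_seqlens)   # tensor([0, 3, 4], dtype=torch.int32) -- per-sequence offsets
    print(max_seqlen)   # 3

    repadded = _pad_output(flat, indices, batch_size=2, sequence_length=4)
    print(repadded.shape)  # torch.Size([2, 4, 8]); padded positions are zeros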
 
 
63
 
64
+ class CastedLinear(nn.Linear):
65
  def __init__(self, in_features, out_features, bias):
66
  super().__init__(in_features, out_features, bias=bias)
67
 
 
 
 
 
68
  def forward(self, x):
69
  return F.linear(x, self.weight.type_as(x), bias=self.bias.type_as(x) if self.bias is not None else None)
70
 
71
 
72
  class CastedLinearIn(nn.Linear):
 
73
  def __init__(self, in_features, out_features, bias):
74
  super().__init__(in_features, out_features, bias=bias)
75
  self.scale = nn.Parameter(torch.ones(in_features))
76
 
 
 
 
 
77
  def forward(self, x):
78
  return F.linear(x, (self.weight * (self.scale + 1.0).unsqueeze(0)).type_as(x), bias=self.bias.type_as(x) if self.bias is not None else None)
79
 
80
 
81
  class CastedLinearOut(nn.Linear):
 
82
  def __init__(self, in_features, out_features, bias):
83
  super().__init__(in_features, out_features, bias=bias)
84
  self.scale = nn.Parameter(torch.ones(out_features))
85
 
 
 
 
 
86
  def forward(self, x):
87
  return F.linear(x, (self.scale.unsqueeze(1) * self.weight).type_as(x), bias=self.bias.type_as(x) if self.bias is not None else None)
88
 
89
 
90
  class MultiCastedLinearOrtho(nn.Module):
 
91
  def __init__(self, in_features, out_features, bias):
92
  super().__init__()
93
  self.in_features = in_features
 
102
  else:
103
  self.bias = self.register_parameter("bias", None)
104
 
105
  def forward(self, x):
106
  return F.linear(x, torch.cat([weight for weight in self.weights], dim=0).type_as(x), bias=self.bias.type_as(x) if self.bias is not None else None)
107
 
108
 
109
  class MultiCastedLinearOrthoIn(nn.Module):
 
110
  def __init__(self, in_features, out_features, bias):
111
  super().__init__()
112
  self.in_features = in_features
 
123
 
124
  self.scale = nn.Parameter(torch.ones(in_features))
125
 
126
  def forward(self, x):
127
  return F.linear(x, (torch.cat([weight for weight in self.weights], dim=0) * (self.scale + 1.0).unsqueeze(0)).type_as(x), bias=self.bias.type_as(x) if self.bias is not None else None)
128
 
129
 
130
  class MultiCastedLinearOrthoOut(nn.Module):
 
131
  def __init__(self, in_features, out_features, bias):
132
  super().__init__()
133
+
134
  self.in_features = in_features
135
  self.out_features = out_features
136
 
 
145
 
146
  self.scale = nn.Parameter(torch.ones(sum(out_features)))
147
 
148
  def forward(self, x):
149
  return F.linear(x, (self.scale.unsqueeze(1) * torch.cat([weight for weight in self.weights], dim=0)).type_as(x), bias=self.bias.type_as(x) if self.bias is not None else None)
150
 
 
152
  class GeGLU(nn.Module):
153
  def forward(self, x):
154
  x, gate = x.chunk(2, dim=-1)
155
+ return x * gelu_new(gate)
 
156
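GeGLU halves the last dimension: the up-projection emits two stacked blocks of intermediate_size, one of which gates the other through gelu_new. A quick shape check, assuming the class above in module scope:

    import torch

    geglu = GeGLU()
    x = torch.randn(2, 5, 2 * 384)  # up_proj output: two stacked intermediate blocks
    print(geglu(x).shape)           # torch.Size([2, 5, 384]) -- gated, half the width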
 
157
 
158
  class MaskedSoftmax(torch.autograd.Function):
159
  @staticmethod
160
+ def forward(ctx, x: torch.Tensor, mask: torch.BoolTensor, dim: int):
 
 
161
  ctx.dim = dim
162
  x.masked_fill_(mask, float('-inf'))
163
  x = torch.softmax(x, ctx.dim)
 
166
  return x
167
 
168
  @staticmethod
169
+ def backward(ctx, grad_output: torch.Tensor):
 
 
170
  output, = ctx.saved_tensors
171
  inputGrad: torch.Tensor = _softmax_backward_data(grad_output, output, ctx.dim, output.dtype)
172
  return inputGrad, None, None
173
 
174
 
175
  class Encoder(nn.Module):
176
+ def __init__(self, config: GptBertConfig):
 
177
  super().__init__()
178
 
 
 
179
  self.layers = nn.ModuleList([Layer(config, i) for i in range(config.num_layers)])
 
 
 
 
 
 
180
  self.short_long_ratio = config.short_long_ratio
181
 
182
+ def set_window_length(self, config: GptBertConfig):
183
  for i, layer in enumerate(self.layers):
184
  if (i+1) % self.short_long_ratio == 0:
185
+ layer.set_window_length(config.window_length)
186
  else:
187
+ layer.set_window_length(256)
 
 
 
 
188
 
189
+ def forward(self, hidden_layer: torch.Tensor, padding_info):
190
  hidden_states = []
191
  attention_probs = []
192
  v1 = None
193
+ embeddings = hidden_layer
194
 
195
  for layer in self.layers:
196
+ hidden_layer, v1, attention_p = layer(hidden_layer, embeddings, v1, padding_info)
197
  hidden_states.append(hidden_layer)
198
  attention_probs.append(attention_p)
199
 
 
201
 
202
 
203
  class Layer(nn.Module):
204
+ def __init__(self, config: GptBertConfig, layer_idx: int):
 
205
  super().__init__()
206
 
 
 
 
207
  self.attention = SelfAttention(config, layer_idx)
208
  self.mlp = FeedForward(config)
209
  self.lambdas = nn.Parameter(torch.tensor([0., 0., 1., 0., 1., 0.]))
210
 
211
+ def set_window_length(self, window_length: int):
212
+ self.attention.set_window_length(window_length)
 
 
 
 
213
 
214
+ def forward(self, hidden_layer: torch.Tensor, embeddings: torch.Tensor, v1: torch.Tensor | None, padding_info):
215
  attention_output = (1 - self.lambdas[0]) * hidden_layer + self.lambdas[0] * embeddings
216
  qk_layer = (1 - self.lambdas[1]) * hidden_layer + self.lambdas[1] * embeddings
217
  mlp_layer = F.softplus(self.lambdas[2]) * ((1 - self.lambdas[3]) * hidden_layer + self.lambdas[3] * embeddings)
218
 
219
+ attention_output, v1, attention_p = self.attention(attention_output, qk_layer, v1, padding_info)
220
  mlp_layer = mlp_layer + attention_output
221
  hidden_layer = F.softplus(self.lambdas[4]) * ((1 - self.lambdas[5]) * hidden_layer + self.lambdas[5] * embeddings)
222
  output = hidden_layer + attention_output + self.mlp(mlp_layer)
 
225
 
226
 
227
  class Embedding(nn.Module):
228
+ def __init__(self, config: GptBertConfig):
 
229
  super().__init__()
230
 
231
  assert hasattr(config, "vocab_size"), "The config must have a vocab_size attribute!"
232
  assert hasattr(config, "hidden_size"), "The config must have a hidden_size attribute!"
233
  assert hasattr(config, "embedding_dropout_p"), "The model must have a embedding_dropout_p attribute!"
234
 
 
 
 
 
235
  self.word_embedding = nn.Embedding(config.vocab_size, config.hidden_size)
236
  self.word_norm = nn.LayerNorm(config.hidden_size, eps=config.word_norm_eps, elementwise_affine=False, bias=False)
237
  self.word_scale = nn.Parameter(torch.zeros(config.hidden_size))
238
 
239
  self.dropout = nn.Dropout(config.embedding_dropout_p)
240
 
241
+ def forward(self, input_ids: torch.Tensor):
242
  word_embedding = self.word_embedding(input_ids)
243
  word_embedding = self.word_norm(word_embedding)
244
+ word_embedding = word_embedding * (self.word_scale + 1.0)
245
 
246
  return self.dropout(word_embedding)
247
 
248
 
249
  class MaskClassifier(nn.Module):
250
+ def __init__(self, config: GptBertConfig, embedding_weights: nn.Parameter):
 
251
  super().__init__()
252
 
 
 
 
 
 
253
  self.pre_norm = nn.LayerNorm(config.hidden_size, eps=config.classifier_pre_norm_eps, elementwise_affine=config.classifier_pre_norm_affine)
254
  self.projection = CastedLinearIn(config.hidden_size, config.hidden_size, bias=False)
255
  self.post_norm = nn.LayerNorm(config.hidden_size, eps=config.classifier_post_norm_eps, elementwise_affine=config.classifier_post_norm_affine)
256
  self.emb2vocab = CastedLinearIn(config.hidden_size, config.vocab_size, bias=True)
257
 
258
+ def forward(self, x: torch.Tensor):
259
+ x = self.pre_norm(x)
260
+ x = self.projection(x)
261
+ x = gelu_new(x)
262
+ x = self.post_norm(x)
263
+ return self.emb2vocab(x)
264
+
265
+
266
+ def flash_attention_forward(
267
+ qkv: torch.Tensor,
268
+ rotary_emb: "UnpaddedRotaryEmbedding",  # string annotation: the class is defined later in this module
269
+ cu_seqlens: torch.Tensor,
270
+ max_seqlen: int,
271
+ local_attention: Tuple[int, int],
272
+ dropout_p: float,
273
+ deterministic: bool,
274
+ target_dtype: torch.dtype = torch.bfloat16,
275
+ **_kwargs,
276
+ ):
277
+ qkv = rotary_emb(qkv, cu_seqlens=cu_seqlens, max_seqlen=max_seqlen)
278
+
279
+ convert_dtype = qkv.dtype not in (torch.float16, torch.bfloat16)
280
+ if convert_dtype:
281
+ # FA2 implementation only supports fp16 and bf16. If FA2 is supported,
282
+ # bfloat16 must be supported as of FA2 2.5.7. (Turing GPUs not supported)
283
+ orig_dtype = qkv.dtype
284
+ qkv = qkv.to(target_dtype)
285
+
286
+ attn = flash_attn_varlen_qkvpacked_func(
287
+ qkv,
288
+ cu_seqlens=cu_seqlens,
289
+ max_seqlen=max_seqlen,
290
+ dropout_p=dropout_p,
291
+ deterministic=deterministic,
292
+ window_size=local_attention,
293
+ )
294
+ attn = attn.to(orig_dtype) # type: ignore
295
+ else:
296
+ attn = flash_attn_varlen_qkvpacked_func(
297
+ qkv,
298
+ cu_seqlens=cu_seqlens,
299
+ max_seqlen=max_seqlen,
300
+ dropout_p=dropout_p,
301
+ deterministic=deterministic,
302
+ window_size=local_attention,
303
+ )
304
+ return attn
305
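Here window_size follows the flash-attn convention of (left, right) token counts, with (-1, -1) meaning unrestricted, and cu_seqlens marks the boundaries between the concatenated sequences. A hedged call sketch, assuming flash-attn 2.x on a CUDA device and a packed qkv tensor laid out as in SelfAttention.forward below:

    import torch

    # Four unpadded tokens from two sequences (lengths 3 and 1), 8 heads of width 64.
    qkv = torch.randn(4, 3, 8, 64, dtype=torch.bfloat16, device="cuda")
    cu_seqlens = torch.tensor([0, 3, 4], dtype=torch.int32, device="cuda")

    out = flash_attn_varlen_qkvpacked_func(
        qkv,
        cu_seqlens=cu_seqlens,
        max_seqlen=3,
        dropout_p=0.0,
        window_size=(127, 127),  # bidirectional window, as a window_length of 128 would yield
        deterministic=False,
    )
    print(out.shape)  # torch.Size([4, 8, 64]) -- one output row per unpadded token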
 
306
 
307
  class SelfAttention(nn.Module):
308
+ def __init__(self, config: GptBertConfig, layer_idx: int):
 
309
  super().__init__()
310
+
311
+ self.config = config
312
+ self.layer_idx = layer_idx
313
+
314
  self.d_qk = config.d_qk
315
  self.d_v = config.d_v
316
  self.num_attention_heads = config.num_attention_heads
 
334
  self.q_scale = nn.Parameter(torch.ones(self.num_attention_heads, config.d_qk))
335
 
336
  self.dropout = nn.Dropout(config.attention_output_dropout_p)
337
+ self.attention_dropout = config.attention_dropout if hasattr(config, "attention_dropout") else 0.0
338
+ self.deterministic_flash_attn = getattr(config, "deterministic_flash_attn", False)
339
 
340
  theta = 160_000 if (layer_idx + 1) % config.short_long_ratio == 0 else 10_000
341
 
342
+ # Initialize rotary embeddings based on whether FlashAttention is available
343
+ if is_flash_attn_2_available():
344
+ self.rope_embedding = UnpaddedRotaryEmbedding(
345
+ dim=config.d_qk,
346
+ base=theta,
347
+ max_seqlen=config.max_sequence_length,
348
+ device=None,
349
+ dtype=None
350
+ )
351
+ else:
352
+ self.rope_embedding = RotaryPositionalEmbeddings(config, theta)
353
 
354
+ self.scale = 1.0 / math.sqrt(self.d_qk)
355
  self.dropout = nn.Dropout(config.attention_dropout if hasattr(config, "attention_dropout") else 0.0)
356
 
357
  self.lambdas = nn.Parameter(torch.tensor([0.5]))
358
 
 
 
359
  self.sequence_length = config.max_sequence_length
360
  self.is_causal = config.is_decoder
361
+ self.window_length = None
362
 
363
+ def set_window_length(self, window_length: int):
364
+ self.window_length = window_length
 
 
365
 
366
+ @lru_cache(maxsize=32)
367
+ def _get_window_mask(self, query_length: int, key_length: int, device: torch.device):
368
+ """Create and cache window attention mask."""
 
 
 
 
369
  if self.is_causal:
370
+ mask = torch.ones(query_length, key_length, dtype=torch.bool, device=device)
371
+ mask = ~mask.tril().triu(diagonal=-self.window_length)
 
 
372
  else:
373
+ mask = torch.ones(query_length, key_length, dtype=torch.bool, device=device)
374
+ mask = ~mask.tril(diagonal=self.window_length).triu(diagonal=-self.window_length)
375
+ return mask.view(1, 1, query_length, key_length)
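The convention here is True = position is masked out. For intuition, a standalone causal example with window_length = 2 over five positions (equivalent to the cached method above, minus the lru_cache plumbing) — each query may attend to itself and at most the two preceding positions:

    import torch

    window_length, T = 2, 5
    allowed = torch.ones(T, T, dtype=torch.bool).tril().triu(diagonal=-window_length)
    print(~allowed)
    # tensor([[False,  True,  True,  True,  True],
    #         [False, False,  True,  True,  True],
    #         [False, False, False,  True,  True],
    #         [ True, False, False, False,  True],
    #         [ True,  True, False, False, False]])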
376
 
377
+ def attention_operation(self, query: torch.Tensor, key: torch.Tensor, value: torch.Tensor, padding_mask: Optional[torch.Tensor]) -> Tuple[torch.Tensor, torch.Tensor]:
378
+ """Standard attention computation with masking."""
379
  batch_size, _, query_length, _ = query.size()
380
  _, _, key_length, _ = key.size()
381
 
382
+ # Use cached window mask
383
+ with torch.no_grad():
384
+ window_mask = self._get_window_mask(query_length, key_length, query.device)
385
+
386
+ if padding_mask is not None:
387
+ attention_mask = padding_mask | window_mask
388
+ else:
389
+ attention_mask = window_mask
 
390
 
391
  attention_scores = torch.bmm(query.flatten(0, 1), key.transpose(-1, -2).flatten(0, 1)) * self.scale # shape: [B*H, T, T]
392
  attention_scores = attention_scores.view(batch_size, self.num_attention_heads, query_length, key_length)
 
399
 
400
  return value, attention_probabilities.detach()
401
 
402
+ def forward(self, hidden_layer: torch.Tensor, qk_layer: torch.Tensor, v1: torch.Tensor | None, padding_info):
403
+ # Get original shape info
404
+ if is_flash_attn_2_available() and isinstance(padding_info, tuple):
405
+ # Unpadded case
406
+ indices, cu_seqlens, max_seqlen = padding_info
407
+ total_seqlen = hidden_layer.size(0)
408
+ batch_size = cu_seqlens.size(0) - 1
409
+ else:
410
+ # Padded case
411
+ batch_size, seq_length = hidden_layer.size(0), hidden_layer.size(1)
412
+ hidden_layer = hidden_layer.transpose(0, 1) # [seq_len, batch_size, hidden_size]
413
+ qk_layer = qk_layer.transpose(0, 1)
414
+
415
+
416
  hidden_layer = self.pre_v_norm(hidden_layer)
417
  qk_layer = self.pre_qk_norm(qk_layer)
418
 
419
  query, key = self.qk_proj(qk_layer).tensor_split([self.q_out_dim], dim=-1)
420
  value = self.v_proj(hidden_layer)
421
 
422
+ if is_flash_attn_2_available() and isinstance(padding_info, tuple):
423
+ # Reshape for FlashAttention: (total_seqlen, num_heads, head_dim)
424
+ query = query.view(total_seqlen, self.num_attention_heads, self.d_qk)
425
+ key = key.view(total_seqlen, self.num_kv_heads, self.d_qk)
426
+ value = value.view(total_seqlen, self.num_kv_heads, self.d_v)
427
 
428
+ # Apply layer norm and scaling
429
+ query = ((self.q_scale + 1.0).unsqueeze(0) * self.q_norm(query.float())).type_as(query)
430
+ key = ((self.k_scale + 1.0).unsqueeze(0) * self.k_norm(key.float())).type_as(key)
431
 
432
+ if v1 is None:
433
+ v1 = value
434
+ value = (1 - self.lambdas[0]) * value + self.lambdas[0] * v1
435
 
436
+ # Prepare qkv for FlashAttention
437
+ if self.num_kv_heads == self.num_attention_heads:
438
+ # Standard MHA
439
+ qkv = torch.stack([query, key, value], dim=1) # (total_seqlen, 3, num_heads, head_dim)
440
+ else:
441
+ # GQA case - need to repeat k,v heads
442
+ num_rep = self.num_attention_heads // self.num_kv_heads
443
+ key = key.repeat_interleave(num_rep, dim=1)
444
+ value = value.repeat_interleave(num_rep, dim=1)
445
+ qkv = torch.stack([query, key, value], dim=1)
446
+
447
+ # Determine window size for local attention
448
+ if self.window_length is not None and self.window_length > 0:
449
+ if self.is_causal:
450
+ local_attention = (self.window_length - 1, 0)
451
+ else:
452
+ local_attention = (self.window_length - 1, self.window_length - 1)
453
+ else:
454
+ local_attention = (-1, -1)
455
+
456
+ # Apply FlashAttention
457
+ output = flash_attention_forward(
458
+ qkv,
459
+ self.rope_embedding,
460
+ cu_seqlens,
461
+ max_seqlen,
462
+ local_attention,
463
+ self.attention_dropout if self.training else 0.0,
464
+ self.deterministic_flash_attn
465
+ )
466
 
467
+ # Reshape output back
468
+ output = output.view(total_seqlen, self.d_v * self.num_attention_heads)
469
+ attention_probabilities = None
470
 
 
 
471
  else:
472
+ # Standard attention path
473
+ query_length = hidden_layer.size(0)
474
+ key_length = hidden_layer.size(0)
475
 
476
+ query = query.reshape(query_length, batch_size, self.num_attention_heads, self.d_qk).permute(1, 2, 0, 3)
477
+ key = key.reshape(key_length, batch_size, self.num_kv_heads, self.d_qk).permute(1, 2, 0, 3)
478
+ value = value.reshape(key_length, batch_size, self.num_kv_heads, self.d_v).permute(1, 2, 0, 3)
 
479
 
480
+ query = ((self.q_scale + 1.0).unsqueeze(1).unsqueeze(0) * self.q_norm(query.float())).type_as(query)
481
+ key = ((self.k_scale + 1.0).unsqueeze(1).unsqueeze(0) * self.k_norm(key.float())).type_as(key)
 
482
 
483
+ if v1 is None:
484
+ v1 = value
485
+ value = (1 - self.lambdas[0]) * value + self.lambdas[0] * v1
486
 
487
+ # Apply rotary embeddings
488
+ query = self.rope_embedding(query)
489
+ key = self.rope_embedding(key)
490
 
491
+ # Handle GQA for standard attention
492
+ if self.num_kv_heads != self.num_attention_heads:
493
+ num_rep = self.num_attention_heads // self.num_kv_heads
494
+ key = key.repeat_interleave(num_rep, dim=1)
495
+ value = value.repeat_interleave(num_rep, dim=1)
496
 
497
+ # padding_info is the [batch, seq] attention mask here (True = real token);
+ # invert and broadcast it to the [B, 1, 1, T] "True = masked out" convention used above.
+ padding_mask = None if isinstance(padding_info, tuple) else ~padding_info.bool().unsqueeze(1).unsqueeze(2)
+ output, attention_probabilities = self.attention_operation(query, key, value, padding_mask)
498
+ output = output.permute(2, 0, 1, 3).flatten(2, 3) # shape: [T, B, H*D]
499
 
500
+ output = self.inter_norm(output)
501
+ output = self.out_proj(output)
 
 
 
 
502
 
503
+ # Handle output padding if necessary
504
+ if is_flash_attn_2_available() and isinstance(padding_info, tuple):
505
+ # Already in correct format for unpadded
506
+ pass
507
+ else:
508
+ # Transpose back to [batch_size, seq_len, hidden_size]
509
+ output = output.transpose(0, 1)
510
+
511
+ return self.dropout(output), v1, attention_probabilities
512
+
513
+ class FeedForward(nn.Module):
514
+ def __init__(self, config: GptBertConfig):
515
+ super().__init__()
516
  self.pre_norm = nn.LayerNorm(config.hidden_size, eps=config.feed_forward_pre_norm_eps, elementwise_affine=config.feed_forward_pre_norm_affine)
517
  self.up_proj = MultiCastedLinearOrthoIn(config.hidden_size, [config.intermediate_size, config.intermediate_size], bias=False)
518
  self.activation = GeGLU()
519
  self.inter_norm = nn.LayerNorm(config.intermediate_size, eps=config.feed_forward_inter_norm_eps, elementwise_affine=config.feed_forward_inter_norm_affine)
520
  self.down_proj = CastedLinearIn(config.intermediate_size, config.hidden_size, bias=False)
521
  self.dropout = nn.Dropout(config.feed_forward_dropout_p)
522
+
523
+ def forward(self, x: torch.Tensor):
524
+ x = self.pre_norm(x)
525
+ x = self.up_proj(x)
526
+ x = self.activation(x)
527
+ x = self.inter_norm(x.float()).type_as(x)
528
+ x = self.down_proj(x)
529
+ return self.dropout(x)
530
 
 
531
 
532
+ class ApplyRotaryEmbUnpad(torch.autograd.Function):
533
+ @staticmethod
534
+ def forward(
535
+ ctx,
536
+ qkv,
537
+ cos,
538
+ sin,
539
+ cu_seqlens: Optional[torch.Tensor] = None,
540
+ max_seqlen: Optional[int] = None,
541
+ ):
542
+ # (total_nnz, 3, nheads, headdim)
543
+ qkv = qkv.contiguous()
544
+ total_nnz, _three, _nheads, headdim = qkv.shape
545
+ # We need qkv to be contiguous so that when we reshape to combine (3, nheads) dimensions,
546
+ # we get the same tensor
547
+ # qk = rearrange(qkv[:, :2], "b_s t h d -> b_s (t h) d")
548
+ qk = qkv[:, :2].view(total_nnz, -1, headdim)
549
+ apply_rotary(
550
+ qk,
551
+ cos,
552
+ sin,
553
+ seqlen_offsets=0,
554
+ cu_seqlens=cu_seqlens,
555
+ max_seqlen=max_seqlen,
556
+ interleaved=False,
557
+ inplace=True,
558
+ )
559
 
560
+ ctx.save_for_backward(cos, sin, cu_seqlens)
561
+ ctx.max_seqlen = max_seqlen
562
+ return qkv
563
 
564
+ @staticmethod
565
+ def backward(ctx, do):
566
+ cos, sin, cu_seqlens = ctx.saved_tensors
567
+ do = do.contiguous()
568
+ total_nnz, _three, _nheads, headdim = do.shape
569
+ # We need dqkv to be contiguous so that when we reshape to combine (3, nheads) dimensions,
570
+ # we get the same tensor
571
+ dqk = do[:, :2].view(total_nnz, -1, headdim)
572
+ apply_rotary(
573
+ dqk,
574
+ cos,
575
+ sin,
576
+ seqlen_offsets=0,
577
+ cu_seqlens=cu_seqlens,
578
+ max_seqlen=ctx.max_seqlen,
579
+ interleaved=False,
580
+ inplace=True,
581
+ conjugate=True,
582
+ )
583
 
584
+ return do, None, None, None, None, None, None
 
585
 
 
 
586
 
587
+ def apply_rotary_unpadded(
588
+ qkv,
589
+ cos,
590
+ sin,
591
+ cu_seqlens: Optional[torch.Tensor] = None,
592
+ max_seqlen: Optional[int] = None,
593
+ ):
594
+ return ApplyRotaryEmbUnpad.apply(qkv, cos, sin, cu_seqlens, max_seqlen)
595
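A note on the [:, :2] view above: rotary embeddings are applied in place to the query and key slots of the packed (total_tokens, 3, heads, dim) tensor, leaving the value slot untouched. A pure-torch reference of the same non-interleaved ("rotate-half") transform, handy for sanity-checking the Triton kernel on a few tokens:

    import torch

    def rotate_half_ref(x, cos, sin):
        # x: [tokens, heads, dim]; cos/sin: [tokens, dim // 2]
        half = x.shape[-1] // 2
        x1, x2 = x[..., :half], x[..., half:]
        cos, sin = cos.unsqueeze(1), sin.unsqueeze(1)  # broadcast over heads
        return torch.cat([x1 * cos - x2 * sin, x1 * sin + x2 * cos], dim=-1)

    t = torch.arange(4, dtype=torch.float32)                  # token positions
    inv_freq = 1.0 / (10_000 ** (torch.arange(0, 8, 2) / 8))  # dim = 8, so 4 frequencies
    freqs = torch.outer(t, inv_freq)
    q = torch.randn(4, 2, 8)                                  # [tokens, heads, dim]
    print(rotate_half_ref(q, freqs.cos(), freqs.sin()).shape) # torch.Size([4, 2, 8])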
 
 
 
596
 
597
+ class UnpaddedRotaryEmbedding(RotaryEmbedding):
598
+ def __init__(self, dim: int, base: float = 10000.0, max_seqlen: Optional[int] = None, device: Optional[torch.device] = None, dtype: Optional[torch.dtype] = None):
599
+ super().__init__(dim=dim, base=base, pos_idx_in_fp32=True, device=device, interleaved=False)
600
+ self.max_seqlen = max_seqlen
601
 
602
+ if max_seqlen is not None and device is not None and dtype is not None:
603
+ self._update_cos_sin_cache(max_seqlen, device=device, dtype=dtype)
604
 
605
+ def forward(self, qkv: torch.Tensor, cu_seqlens: torch.Tensor, max_seqlen: Optional[int] = None) -> Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]:
606
+ if max_seqlen is not None:
607
+ self._update_cos_sin_cache(max_seqlen, device=qkv.device, dtype=qkv.dtype)
608
 
609
+ qkv = apply_rotary_unpadded(
610
+ qkv,
611
+ self._cos_cached,
612
+ self._sin_cached,
613
+ cu_seqlens=cu_seqlens,
614
+ max_seqlen=max_seqlen,
615
+ )
616
 
617
+ return qkv
618
 
619
 
620
  class RotaryPositionalEmbeddings(nn.Module):
621
+ def __init__(self, config, theta: int):
 
622
  super().__init__()
623
 
624
  assert hasattr(config, "d_qk"), "The config must have a d_qk attribute!"
 
644
  self.register_buffer("cos_matrix", embedding.cos(), persistent=False)
645
  self.register_buffer("sin_matrix", embedding.sin(), persistent=False)
646
 
647
+ def forward(self, x: torch.Tensor):
648
  seq_len: int
649
  cos_matrix: torch.Tensor
650
  sin_matrix: torch.Tensor
 
676
 
677
  class GptBertPreTrainedModel(PreTrainedModel):
678
  config_class = GptBertConfig
679
+ supports_gradient_checkpointing = True
680
+ _supports_flash_attn_2 = True
681
+ _supports_sdpa = True
682
+ _supports_flex_attn = False
683
 
684
  def _init_weights(self, module):
685
  pass
686
 
687
 
688
  class GptBertModel(GptBertPreTrainedModel):
689
+ def __init__(self, config: GptBertConfig, add_mlm_layer=False, **kwargs):
 
690
  super().__init__(config, **kwargs)
691
  self.config = config
692
  self.hidden_size = config.hidden_size
 
708
  def get_contextualized_embeddings(
709
  self,
710
  input_ids: Optional[torch.Tensor] = None,
711
+ attention_mask: Optional[torch.Tensor] = None,
712
+ output_hidden_states: Optional[bool] = None
713
  ) -> List[torch.Tensor]:
714
  if input_ids is not None:
715
  input_shape = input_ids.size()
 
719
  batch_size, seq_length = input_shape
720
  device = input_ids.device
721
 
722
+ if attention_mask is None:
723
+ attention_mask = torch.ones(batch_size, seq_length, dtype=torch.bool, device=device)
724
+ else:
725
+ if len(attention_mask.size()) != 2:
726
+ raise ValueError("Only attention mask with two dimensions is supported now.")
 
 
 
 
727
 
728
+ if is_flash_attn_2_available():
729
+ input_ids, indices, cu_seqlens, max_seqlen_in_batch = _unpad_input(input_ids, attention_mask)
730
+ padding_info = (indices, cu_seqlens, max_seqlen_in_batch)
731
+ else:
732
+ padding_info = attention_mask
733
 
734
+ static_embeddings = self.embedding(input_ids)
735
+ contextualized_embeddings, attention_probs = self.encoder(static_embeddings, padding_info)
 
736
  last_layer = contextualized_embeddings[-1]
737
+
738
+ # Pad output if using FlashAttention
739
+ if is_flash_attn_2_available():
740
+ last_layer = _pad_output(last_layer, indices, batch_size, seq_length)
741
+ if output_hidden_states:
742
+ contextualized_embeddings = [_pad_output(layer, indices, batch_size, seq_length) for layer in contextualized_embeddings]
743
+ else:
744
+ contextualized_embeddings = None
745
+
746
  return last_layer, contextualized_embeddings, attention_probs
747
 
748
  def forward(
749
  self,
750
  input_ids: Optional[torch.Tensor] = None,
751
  attention_mask: Optional[torch.Tensor] = None,
 
 
752
  output_hidden_states: Optional[bool] = None,
753
  output_attentions: Optional[bool] = None,
754
  return_dict: Optional[bool] = None,
 
756
  ) -> Union[Tuple[torch.Tensor], BaseModelOutput]:
757
  return_dict = return_dict if return_dict is not None else self.config.use_return_dict
758
 
759
+ sequence_output, contextualized_embeddings, attention_probs = self.get_contextualized_embeddings(
760
+ input_ids, attention_mask, output_hidden_states
761
+ )
762
 
763
  if not return_dict:
764
  return (
 
777
  class GptBertForMaskedLM(GptBertModel):
778
  _keys_to_ignore_on_load_unexpected = ["head"]
779
 
780
+ def __init__(self, config: GptBertConfig, **kwargs):
781
  super().__init__(config, add_mlm_layer=True, **kwargs)
782
 
783
  def get_output_embeddings(self):
 
831
 
832
 
833
  class Classifier(nn.Module):
834
+ def __init__(self, config: GptBertConfig, num_labels: int):
835
  super().__init__()
836
 
837
  drop_out = getattr(config, "cls_dropout", None)
 
858
  nn.init.trunc_normal_(self.emb2vocab.weight, mean=0.0, std=proj_std, a=-2*proj_std, b=2*proj_std)
859
  self.emb2vocab.bias.zero_()
860
 
861
+ def forward(self, x: torch.Tensor):
862
+ x = self.pre_norm(x)
863
+ x = self.dropout(x)
864
+ x = self.projection(x)
865
+ x = gelu_new(x)
866
+ x = self.post_norm(x)
867
+ return self.emb2vocab(x)
868
 
869
 
870
  class GptBertForCausalLM(GptBertModel):
871
  _keys_to_ignore_on_load_unexpected = ["head"]
872
 
873
+ def __init__(self, config: GptBertConfig, **kwargs):
874
  config.is_decoder = True
875
  super().__init__(config, add_mlm_layer=True, **kwargs)
876
 
 
995
  class GptBertForSequenceClassification(GptBertModel):
996
  _keys_to_ignore_on_load_unexpected = ["classifier"]
997
 
998
+ def __init__(self, config: GptBertConfig, **kwargs):
999
  super().__init__(config, add_mlm_layer=False, **kwargs)
1000
 
1001
  self.num_labels = config.num_labels
 
1060
  class GptBertForTokenClassification(GptBertModel):
1061
  _keys_to_ignore_on_load_unexpected = ["classifier"]
1062
 
1063
+ def __init__(self, config: GptBertConfig, **kwargs):
1064
  super().__init__(config, add_mlm_layer=False, **kwargs)
1065
 
1066
  self.num_labels = config.num_labels
 
1107
  class GptBertForQuestionAnswering(GptBertModel):
1108
  _keys_to_ignore_on_load_unexpected = ["classifier"]
1109
 
1110
+ def __init__(self, config: GptBertConfig, **kwargs):
1111
  super().__init__(config, add_mlm_layer=False, **kwargs)
1112
 
1113
  self.num_labels = config.num_labels
 
1174
  class GptBertForMultipleChoice(GptBertModel):
1175
  _keys_to_ignore_on_load_unexpected = ["classifier"]
1176
 
1177
+ def __init__(self, config: GptBertConfig, **kwargs):
1178
  super().__init__(config, add_mlm_layer=False, **kwargs)
1179
 
1180
  self.num_labels = getattr(config, "num_labels", 2)