davda54 committed (verified)
Commit 7741265 · 1 Parent(s): 4ba37ea

Update modeling_gptbert.py

Files changed (1)
  1. modeling_gptbert.py +32 -77
modeling_gptbert.py CHANGED
@@ -96,6 +96,7 @@ class CastedLinearIn(nn.Linear):
 class MultiCastedLinearOrthoIn(nn.Module):
     def __init__(self, in_features, out_features, bias):
         super().__init__()
+
         self.in_features = in_features
         self.out_features = out_features
 
@@ -178,15 +179,10 @@ class Embedding(nn.Module):
     def __init__(self, config: GptBertConfig):
         super().__init__()
 
-        assert hasattr(config, "vocab_size"), "The config must have a vocab_size attribute!"
-        assert hasattr(config, "hidden_size"), "The config must have a hidden_size attribute!"
-        assert hasattr(config, "embedding_dropout_p"), "The model must have a embedding_dropout_p attribute!"
-
         self.word_embedding = nn.Embedding(config.vocab_size, config.hidden_size)
-        self.word_norm = nn.LayerNorm(config.hidden_size, eps=config.word_norm_eps, elementwise_affine=False, bias=False)
+        self.word_norm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps, elementwise_affine=False, bias=False)
         self.word_scale = nn.Parameter(torch.zeros(config.hidden_size))
-
-        self.dropout = nn.Dropout(config.embedding_dropout_p)
+        self.dropout = nn.Dropout(config.embedding_dropout)
 
     def forward(self, input_ids: torch.Tensor):
         word_embedding = self.word_embedding(input_ids)
@@ -200,9 +196,10 @@ class Classifier(nn.Module):
     def __init__(self, config: GptBertConfig, n_labels: int):
         super().__init__()
 
-        self.pre_norm = nn.LayerNorm(config.hidden_size, eps=config.classifier_pre_norm_eps, elementwise_affine=config.classifier_pre_norm_affine)
+        self.pre_norm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps, elementwise_affine=False)
         self.projection = CastedLinearIn(config.hidden_size, config.hidden_size, bias=False)
-        self.post_norm = nn.LayerNorm(config.hidden_size, eps=config.classifier_post_norm_eps, elementwise_affine=config.classifier_post_norm_affine)
+        self.post_norm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps, elementwise_affine=False)
+        self.dropout = nn.Dropout(config.classifier_dropout)
         self.emb2vocab = CastedLinearIn(config.hidden_size, n_labels, bias=True)
 
     def forward(self, x: torch.Tensor):
@@ -210,22 +207,13 @@ class Classifier(nn.Module):
         x = self.projection(x)
         x = gelu_new(x)
         x = self.post_norm(x.float()).type_as(x)
+        x = self.dropout(x)
         x = self.emb2vocab(x)
         return x
 
 
-def flash_attention_forward(
-    qkv: torch.Tensor,
-    rotary_emb: UnpaddedRotaryEmbedding,
-    cu_seqlens: torch.Tensor,
-    max_seqlen: int,
-    causal: bool,
-    local_attention: Tuple[int, int],
-    dropout_p: float,
-    deterministic: bool,
-    target_dtype: torch.dtype = torch.bfloat16,
-    **_kwargs,
-):
+# from https://github.com/huggingface/transformers/blob/main/src/transformers/models/modernbert/modeling_modernbert.py
+def flash_attention_forward(qkv: torch.Tensor, rotary_emb: UnpaddedRotaryEmbedding, cu_seqlens: torch.Tensor, max_seqlen: int, causal: bool, local_attention: Tuple[int, int], dropout_p: float, deterministic: bool, target_dtype: torch.dtype = torch.bfloat16, **_kwargs):
     qkv = rotary_emb(qkv, cu_seqlens=cu_seqlens, max_seqlen=max_seqlen)
 
     convert_dtype = qkv.dtype not in (torch.float16, torch.bfloat16)
@@ -265,8 +253,8 @@ class SelfAttention(nn.Module):
         self.config = config
         self.layer_idx = layer_idx
 
-        self.d_qk = config.d_qk
-        self.d_v = config.d_v
+        self.d_qk = config.query_key_head_size
+        self.d_v = config.value_head_size
         self.num_attention_heads = config.num_attention_heads
         self.num_kv_heads = config.num_kv_heads
         self.hidden_size = config.hidden_size
@@ -279,23 +267,21 @@ class SelfAttention(nn.Module):
         self.v_proj = CastedLinearIn(self.hidden_size, self.v_out_dim, bias=False)
         self.out_proj = CastedLinearIn(self.d_v*self.num_attention_heads, self.hidden_size, bias=False)
 
-        self.pre_v_norm = nn.LayerNorm(config.hidden_size, eps=config.attention_pre_norm_eps, elementwise_affine=config.attention_pre_norm_affine)
-        self.pre_qk_norm = nn.LayerNorm(config.hidden_size, eps=config.attention_pre_norm_eps, elementwise_affine=config.attention_pre_norm_affine)
-        self.inter_norm = nn.LayerNorm(self.d_v * self.num_attention_heads, eps=config.attention_inter_norm_eps, elementwise_affine=config.attention_inter_norm_affine)
-        self.q_norm = nn.LayerNorm(config.d_qk, eps=config.attention_pre_norm_eps, elementwise_affine=False, bias=False)
-        self.k_norm = nn.LayerNorm(config.d_qk, eps=config.attention_pre_norm_eps, elementwise_affine=False, bias=False)
-        self.k_scale = nn.Parameter(torch.ones(self.num_kv_heads, config.d_qk))
-        self.q_scale = nn.Parameter(torch.ones(self.num_attention_heads, config.d_qk))
+        self.pre_v_norm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps, elementwise_affine=False)
+        self.pre_qk_norm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps, elementwise_affine=False)
+        self.inter_norm = nn.LayerNorm(self.d_v * self.num_attention_heads, eps=config.layer_norm_eps, elementwise_affine=False)
+        self.q_norm = nn.LayerNorm(self.d_qk, eps=config.layer_norm_eps, elementwise_affine=False, bias=False)
+        self.k_norm = nn.LayerNorm(self.d_qk, eps=config.layer_norm_eps, elementwise_affine=False, bias=False)
+        self.k_scale = nn.Parameter(torch.ones(self.num_kv_heads, self.d_qk))
+        self.q_scale = nn.Parameter(torch.ones(self.num_attention_heads, self.d_qk))
 
-        self.dropout = nn.Dropout(config.attention_output_dropout_p)
-        self.attention_dropout = config.attention_dropout if hasattr(config, "attention_dropout") else 0.0
-        self.deterministic_flash_attn = getattr(config, "deterministic_flash_attn", False)
+        self.dropout = nn.Dropout(config.hidden_dropout)
 
         theta = 160_000 if (layer_idx + 1) % config.short_long_ratio == 0 else 10_000
 
         # Initialize rotary embeddings based on whether FlashAttention is available
         if self.config._attn_implementation == "flash_attention_2":
-            self.rope_embedding = UnpaddedRotaryEmbedding(dim=config.d_qk, base=theta, max_seqlen=config.max_sequence_length)
+            self.rope_embedding = UnpaddedRotaryEmbedding(dim=self.d_qk, base=theta, max_seqlen=config.max_sequence_length)
         else:
             self.rope_embedding = RotaryPositionalEmbeddings(config, theta)
 
@@ -338,7 +324,7 @@ class SelfAttention(nn.Module):
             key=key,
             value=value,
             attn_mask=attention_mask,
-            dropout_p=self.attention_dropout if self.training else 0.0,
+            dropout_p=self.config.attention_dropout if self.training else 0.0,
             is_causal=self.is_causal
         )
         return output
@@ -394,8 +380,8 @@ class SelfAttention(nn.Module):
             max_seqlen,
             self.is_causal,
             local_attention,
-            self.attention_dropout if self.training else 0.0,
-            self.deterministic_flash_attn
+            self.config.attention_dropout if self.training else 0.0,
+            self.config.deterministic_flash_attn
         )
 
         # Reshape output back
@@ -434,12 +420,12 @@ class SelfAttention(nn.Module):
 class FeedForward(nn.Module):
     def __init__(self, config: GptBertConfig):
         super().__init__()
-        self.pre_norm = nn.LayerNorm(config.hidden_size, eps=config.feed_forward_pre_norm_eps, elementwise_affine=config.feed_forward_pre_norm_affine)
+        self.pre_norm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps, elementwise_affine=False)
         self.up_proj = MultiCastedLinearOrthoIn(config.hidden_size, [config.intermediate_size, config.intermediate_size], bias=False)
         self.activation = GeGLU()
-        self.inter_norm = nn.LayerNorm(config.intermediate_size, eps=config.feed_forward_inter_norm_eps, elementwise_affine=config.feed_forward_inter_norm_affine)
+        self.inter_norm = nn.LayerNorm(config.intermediate_size, eps=config.layer_norm_eps, elementwise_affine=False)
         self.down_proj = CastedLinearIn(config.intermediate_size, config.hidden_size, bias=False)
-        self.dropout = nn.Dropout(config.feed_forward_dropout_p)
+        self.dropout = nn.Dropout(config.hidden_dropout)
 
     def forward(self, x: torch.Tensor):
         x = self.pre_norm(x.float()).type_as(x)
@@ -451,16 +437,10 @@ class FeedForward(nn.Module):
         return x
 
 
+# from https://github.com/huggingface/transformers/blob/main/src/transformers/models/modernbert/modeling_modernbert.py
 class ApplyRotaryEmbUnpad(torch.autograd.Function):
     @staticmethod
-    def forward(
-        ctx,
-        qkv,
-        cos,
-        sin,
-        cu_seqlens: Optional[torch.Tensor] = None,
-        max_seqlen: Optional[int] = None,
-    ):
+    def forward(ctx, qkv, cos, sin, cu_seqlens: Optional[torch.Tensor] = None, max_seqlen: Optional[int] = None):
         # (total_nnz, 3, nheads, headdim)
         qkv = qkv.contiguous()
         total_nnz, _three, _nheads, headdim = qkv.shape
@@ -468,16 +448,7 @@ class ApplyRotaryEmbUnpad(torch.autograd.Function):
         # we get the same tensor
         # qk = rearrange(qkv[:, :2], "b_s t h d -> b_s (t h) d")
         qk = qkv[:, :2].view(total_nnz, -1, headdim)
-        apply_rotary(
-            qk,
-            cos,
-            sin,
-            seqlen_offsets=0,
-            cu_seqlens=cu_seqlens,
-            max_seqlen=max_seqlen,
-            interleaved=False,
-            inplace=True,
-        )
+        apply_rotary(qk, cos, sin, seqlen_offsets=0, cu_seqlens=cu_seqlens, max_seqlen=max_seqlen, interleaved=False, inplace=True)
 
         ctx.save_for_backward(cos, sin, cu_seqlens)
         ctx.max_seqlen = max_seqlen
@@ -506,10 +477,12 @@ class ApplyRotaryEmbUnpad(torch.autograd.Function):
         return do, None, None, None, None, None, None
 
 
+# from https://github.com/huggingface/transformers/blob/main/src/transformers/models/modernbert/modeling_modernbert.py
 def apply_rotary_unpadded(qkv, cos, sin, cu_seqlens: Optional[torch.Tensor] = None, max_seqlen: Optional[int] = None):
     return ApplyRotaryEmbUnpad.apply(qkv, cos, sin, cu_seqlens, max_seqlen)
 
 
+# from https://github.com/huggingface/transformers/blob/main/src/transformers/models/modernbert/modeling_modernbert.py
 class UnpaddedRotaryEmbedding(RotaryEmbedding):
     def __init__(self, dim: int, base: float = 10000.0, max_seqlen: Optional[int] = None):
         super().__init__(dim=dim, base=base, pos_idx_in_fp32=True, device=None, interleaved=False)
@@ -537,19 +510,7 @@ class RotaryPositionalEmbeddings(nn.Module):
     def __init__(self, config, theta: int):
         super().__init__()
 
-        assert hasattr(config, "d_qk"), "The config must have a d_qk attribute!"
-        assert hasattr(config, "max_sequence_length"), "The config must have a max_sequence_length attribute!"
-
-        self.inv_freq: torch.Tensor
-        self.cos_matrix: torch.Tensor
-        self.sin_matrix: torch.Tensor
-        head_size: int
-        max_seq_len: int
-        inv_freq: torch.Tensor
-        pos: torch.Tensor
-        embedding: torch.Tensor
-
-        head_size = config.d_qk
+        head_size = config.query_key_head_size
         assert head_size % 2 == 0
         max_seq_len = config.max_sequence_length
 
@@ -561,12 +522,6 @@ class RotaryPositionalEmbeddings(nn.Module):
         self.register_buffer("sin_matrix", embedding.sin(), persistent=False)
 
     def forward(self, x: torch.Tensor):
-        seq_len: int
-        cos_matrix: torch.Tensor
-        sin_matrix: torch.Tensor
-        x_rotate_half: torch.Tensor
-        out: torch.Tensor
-
        hidden_layer = x.float()

        seq_len = x.shape[2]
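
Taken together, the added lines change the configuration surface of the model: the per-module LayerNorm epsilons and affine flags (word_norm_eps, classifier_*_norm_eps/affine, attention_*_norm_eps/affine, feed_forward_*_norm_eps/affine) collapse into a single layer_norm_eps with non-affine norms, the per-module dropout probabilities collapse into embedding_dropout, hidden_dropout, classifier_dropout and attention_dropout, and d_qk/d_v are read as query_key_head_size/value_head_size. The sketch below only lists the attributes the updated module reads; the SimpleNamespace container and every default value are illustrative assumptions, not the repository's actual GptBertConfig.

# Minimal sketch of the config attributes consumed by modeling_gptbert.py after this commit.
# The attribute names come from the diff above; every value is a placeholder assumption.
from types import SimpleNamespace

gptbert_config_sketch = SimpleNamespace(
    vocab_size=32_768,                 # placeholder
    hidden_size=768,                   # placeholder
    intermediate_size=2_048,           # placeholder
    num_attention_heads=12,            # placeholder
    num_kv_heads=12,                   # placeholder
    query_key_head_size=64,            # read into self.d_qk (was config.d_qk)
    value_head_size=64,                # read into self.d_v (was config.d_v)
    max_sequence_length=2_048,         # placeholder
    short_long_ratio=4,                # placeholder; controls the 160k/10k RoPE theta schedule
    layer_norm_eps=1e-5,               # single eps shared by every (non-affine) LayerNorm
    embedding_dropout=0.0,             # Embedding
    hidden_dropout=0.0,                # SelfAttention output and FeedForward
    classifier_dropout=0.0,            # Classifier head
    attention_dropout=0.0,             # applied inside SDPA / FlashAttention during training
    deterministic_flash_attn=False,
    _attn_implementation="sdpa",       # or "flash_attention_2"
)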