ammarnasr commited on
Commit
ef3dbad
·
verified ·
1 Parent(s): 79b5fd1

Upload model

Browse files
Files changed (3) hide show
  1. config.json +0 -1
  2. configuration_gpt2mimo.py +1 -2
  3. modeling_gpt2mimo.py +121 -5
config.json CHANGED
@@ -31,6 +31,5 @@
31
  "torch_dtype": "float32",
32
  "transformers_version": "4.41.1",
33
  "use_cache": true,
34
- "attn_implementation":"eager",
35
  "vocab_size": 50257
36
  }
 
31
  "torch_dtype": "float32",
32
  "transformers_version": "4.41.1",
33
  "use_cache": true,
 
34
  "vocab_size": 50257
35
  }
configuration_gpt2mimo.py CHANGED
@@ -159,7 +159,6 @@ class GPT2MIMOConfig(PretrainedConfig):
159
  eos_token_id=50256,
160
  scale_attn_by_inverse_layer_idx=False,
161
  reorder_and_upcast_attn=False,
162
- attn_implementation="eager",
163
  **kwargs,
164
  ):
165
  self.vocab_size = vocab_size
@@ -187,4 +186,4 @@ class GPT2MIMOConfig(PretrainedConfig):
187
  self.bos_token_id = bos_token_id
188
  self.eos_token_id = eos_token_id
189
 
190
- super().__init__(bos_token_id=bos_token_id, eos_token_id=eos_token_id,attn_implementation=attn_implementation, **kwargs)
 
159
  eos_token_id=50256,
160
  scale_attn_by_inverse_layer_idx=False,
161
  reorder_and_upcast_attn=False,
 
162
  **kwargs,
163
  ):
164
  self.vocab_size = vocab_size
 
186
  self.bos_token_id = bos_token_id
187
  self.eos_token_id = eos_token_id
188
 
189
+ super().__init__(bos_token_id=bos_token_id, eos_token_id=eos_token_id, **kwargs)
modeling_gpt2mimo.py CHANGED
@@ -4,6 +4,7 @@ from typing import Optional, Tuple, Union
4
 
5
  import torch
6
  import torch.utils.checkpoint
 
7
  from torch import nn
8
  from torch.nn import CrossEntropyLoss
9
 
@@ -15,10 +16,10 @@ from transformers.modeling_outputs import (
15
  )
16
  from transformers.modeling_utils import PreTrainedModel
17
  from transformers.pytorch_utils import Conv1D, find_pruneable_heads_and_indices, prune_conv1d_layer
18
- from transformers.utils import logging
19
  from transformers.utils.model_parallel_utils import assert_device_map, get_device_map
20
  from .configuration_gpt2mimo import GPT2MIMOConfig
21
-
22
 
23
 
24
 
@@ -249,6 +250,114 @@ class GPT2Attention(nn.Module):
249
 
250
 
251
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
252
  class GPT2MLP(nn.Module):
253
  def __init__(self, intermediate_size, config):
254
  super().__init__()
@@ -266,7 +375,7 @@ class GPT2MLP(nn.Module):
266
  return hidden_states
267
 
268
 
269
- GPT2_ATTENTION_CLASSES = {"eager": GPT2Attention}
270
 
271
 
272
  class GPT2Block(nn.Module):
@@ -533,7 +642,12 @@ class GPT2MIMOModel(GPT2PreTrainedModel):
533
  if self._attn_implementation == "flash_attention_2":
534
  attention_mask = attention_mask if (attention_mask is not None and 0 in attention_mask) else None
535
  elif _use_sdpa:
536
- raise ValueError("The `sdpa` implementation is REMOVED BY ME")
 
 
 
 
 
537
  else:
538
  if attention_mask is not None:
539
  # We create a 3D attention mask from a 2D tensor mask.
@@ -559,7 +673,9 @@ class GPT2MIMOModel(GPT2PreTrainedModel):
559
  if encoder_attention_mask is None:
560
  encoder_attention_mask = torch.ones(encoder_hidden_shape, device=device)
561
  if _use_sdpa:
562
- raise ValueError("The `sdpa` implementation is REMOVED BY ME")
 
 
563
  elif not self._attn_implementation == "flash_attention_2":
564
  encoder_attention_mask = self.invert_attention_mask(encoder_attention_mask)
565
  else:
 
4
 
5
  import torch
6
  import torch.utils.checkpoint
7
+ from packaging import version
8
  from torch import nn
9
  from torch.nn import CrossEntropyLoss
10
 
 
16
  )
17
  from transformers.modeling_utils import PreTrainedModel
18
  from transformers.pytorch_utils import Conv1D, find_pruneable_heads_and_indices, prune_conv1d_layer
19
+ from transformers.utils import logging, is_flash_attn_2_available, get_torch_version
20
  from transformers.utils.model_parallel_utils import assert_device_map, get_device_map
21
  from .configuration_gpt2mimo import GPT2MIMOConfig
22
+ from transformers.modeling_attn_mask_utils import _prepare_4d_attention_mask_for_sdpa, _prepare_4d_causal_attention_mask_for_sdpa
23
 
24
 
25
 
 
250
 
251
 
252
 
253
+ class GPT2SdpaAttention(GPT2Attention):
254
+ """
255
+ GPT2 attention module using torch.nn.functional.scaled_dot_product_attention. This module inherits from
256
+ `GPT2Attention` as the weights of the module stay untouched. The only changes are on the forward pass
257
+ to adapt to the SDPA API.
258
+ """
259
+
260
+ def __init__(self, *args, **kwargs):
261
+ super().__init__(*args, **kwargs)
262
+
263
+ # Idea adapted from transformers.models.bert.modeling_bert.BertSdpaSelfAttention.__init__
264
+ # SDPA with memory-efficient backend is broken in torch==2.1.2 when using non-contiguous inputs and a custom
265
+ # attn_mask, so we need to call `.contiguous()`. This was fixed in torch==2.2.0.
266
+ # Reference: https://github.com/pytorch/pytorch/issues/112577
267
+ self.require_contiguous_qkv = version.parse(get_torch_version()) < version.parse("2.2.0")
268
+
269
+ def forward(
270
+ self,
271
+ hidden_states: Optional[Tuple[torch.FloatTensor]],
272
+ layer_past: Optional[Tuple[torch.Tensor]] = None,
273
+ attention_mask: Optional[torch.FloatTensor] = None,
274
+ head_mask: Optional[torch.FloatTensor] = None,
275
+ encoder_hidden_states: Optional[torch.Tensor] = None,
276
+ encoder_attention_mask: Optional[torch.FloatTensor] = None,
277
+ use_cache: Optional[bool] = False,
278
+ output_attentions: Optional[bool] = False,
279
+ ) -> Tuple[Union[torch.Tensor, Tuple[torch.Tensor]], ...]:
280
+ if output_attentions or head_mask is not None:
281
+ logger.warning_once(
282
+ "`GPT2SdpaAttention` is used but `torch.nn.functional.scaled_dot_product_attention` does not support "
283
+ "`output_attentions=True` or `head_mask`. Falling back to the manual attention implementation, but "
284
+ "specifying the manual implementation will be required from Transformers version v5.0.0 onwards. "
285
+ 'This warning can be removed using the argument `attn_implementation="eager"` when loading the model.'
286
+ )
287
+ return super().forward(
288
+ hidden_states=hidden_states,
289
+ layer_past=layer_past,
290
+ attention_mask=attention_mask,
291
+ head_mask=head_mask,
292
+ encoder_hidden_states=encoder_hidden_states,
293
+ encoder_attention_mask=encoder_attention_mask,
294
+ use_cache=use_cache,
295
+ output_attentions=output_attentions,
296
+ )
297
+
298
+ bsz, q_len, _ = hidden_states.size()
299
+
300
+ # Initial attention projections
301
+ is_cross_attention = encoder_hidden_states is not None
302
+ if is_cross_attention:
303
+ if not hasattr(self, "q_attn"):
304
+ raise ValueError(
305
+ "If class is used as cross attention, the weights `q_attn` have to be defined. "
306
+ "Please make sure to instantiate class with `GPT2SdpaAttention(..., is_cross_attention=True)`."
307
+ )
308
+
309
+ query = self.q_attn(hidden_states)
310
+ key, value = self.c_attn(encoder_hidden_states).split(self.split_size, dim=2)
311
+ attention_mask = encoder_attention_mask
312
+ else:
313
+ query, key, value = self.c_attn(hidden_states).split(self.split_size, dim=2)
314
+
315
+ query = self._split_heads(query, self.num_heads, self.head_dim)
316
+ key = self._split_heads(key, self.num_heads, self.head_dim)
317
+ value = self._split_heads(value, self.num_heads, self.head_dim)
318
+
319
+ # Optional kv caching
320
+ if layer_past is not None:
321
+ past_key = layer_past[0]
322
+ past_value = layer_past[1]
323
+ key = torch.cat((past_key, key), dim=-2)
324
+ value = torch.cat((past_value, value), dim=-2)
325
+
326
+ present = None
327
+ if use_cache is True:
328
+ present = (key, value)
329
+
330
+ # Avoid torch==2.1.2 specific bug for the memory-efficient backend in SDPA
331
+ if self.require_contiguous_qkv and query.device.type == "cuda" and attention_mask is not None:
332
+ query = query.contiguous()
333
+ key = key.contiguous()
334
+ value = value.contiguous()
335
+
336
+ # We dispatch to SDPA's Flash Attention or Efficient kernels via this `is_causal` if statement instead of an inline conditional assignment
337
+ # in SDPA to support both torch.compile's dynamic shapes and full graph options. An inline conditional prevents dynamic shapes from compiling.
338
+ is_causal = True if attention_mask is None and q_len > 1 and not is_cross_attention else False
339
+
340
+ attn_output = torch.nn.functional.scaled_dot_product_attention(
341
+ query,
342
+ key,
343
+ value,
344
+ attn_mask=attention_mask,
345
+ dropout_p=self.attn_dropout.p if self.training else 0.0,
346
+ is_causal=is_causal,
347
+ )
348
+
349
+ # Reshape outputs
350
+ attn_output = attn_output.transpose(1, 2).contiguous()
351
+ attn_output = attn_output.view(bsz, q_len, self.embed_dim)
352
+
353
+ # Final projection
354
+ attn_output = self.c_proj(attn_output)
355
+ attn_output = self.resid_dropout(attn_output)
356
+
357
+ return attn_output, present, None
358
+
359
+
360
+
361
  class GPT2MLP(nn.Module):
362
  def __init__(self, intermediate_size, config):
363
  super().__init__()
 
375
  return hidden_states
376
 
377
 
378
+ GPT2_ATTENTION_CLASSES = {"eager": GPT2Attention, "sdpa": GPT2SdpaAttention}
379
 
380
 
381
  class GPT2Block(nn.Module):
 
642
  if self._attn_implementation == "flash_attention_2":
643
  attention_mask = attention_mask if (attention_mask is not None and 0 in attention_mask) else None
644
  elif _use_sdpa:
645
+ attention_mask = _prepare_4d_causal_attention_mask_for_sdpa(
646
+ attention_mask=attention_mask,
647
+ input_shape=(batch_size, input_shape[-1]),
648
+ inputs_embeds=inputs_embeds,
649
+ past_key_values_length=past_length,
650
+ )
651
  else:
652
  if attention_mask is not None:
653
  # We create a 3D attention mask from a 2D tensor mask.
 
673
  if encoder_attention_mask is None:
674
  encoder_attention_mask = torch.ones(encoder_hidden_shape, device=device)
675
  if _use_sdpa:
676
+ encoder_attention_mask = _prepare_4d_attention_mask_for_sdpa(
677
+ mask=encoder_attention_mask, dtype=inputs_embeds.dtype, tgt_len=input_shape[-1]
678
+ )
679
  elif not self._attn_implementation == "flash_attention_2":
680
  encoder_attention_mask = self.invert_attention_mask(encoder_attention_mask)
681
  else: