Upload model
Browse files
- config.json +0 -1
- configuration_gpt2mimo.py +1 -2
- modeling_gpt2mimo.py +121 -5
config.json
CHANGED
|
@@ -31,6 +31,5 @@
|
|
| 31 |
"torch_dtype": "float32",
|
| 32 |
"transformers_version": "4.41.1",
|
| 33 |
"use_cache": true,
|
| 34 |
-
"attn_implementation":"eager",
|
| 35 |
"vocab_size": 50257
|
| 36 |
}
|
|
|
|
| 31 |
"torch_dtype": "float32",
|
| 32 |
"transformers_version": "4.41.1",
|
| 33 |
"use_cache": true,
|
|
|
|
| 34 |
"vocab_size": 50257
|
| 35 |
}
|
configuration_gpt2mimo.py
CHANGED
|
@@ -159,7 +159,6 @@ class GPT2MIMOConfig(PretrainedConfig):
|
|
| 159 |
eos_token_id=50256,
|
| 160 |
scale_attn_by_inverse_layer_idx=False,
|
| 161 |
reorder_and_upcast_attn=False,
|
| 162 |
-
attn_implementation="eager",
|
| 163 |
**kwargs,
|
| 164 |
):
|
| 165 |
self.vocab_size = vocab_size
|
|
@@ -187,4 +186,4 @@ class GPT2MIMOConfig(PretrainedConfig):
|
|
| 187 |
self.bos_token_id = bos_token_id
|
| 188 |
self.eos_token_id = eos_token_id
|
| 189 |
|
| 190 |
-
super().__init__(bos_token_id=bos_token_id, eos_token_id=eos_token_id,
|
|
|
|
| 159 |
eos_token_id=50256,
|
| 160 |
scale_attn_by_inverse_layer_idx=False,
|
| 161 |
reorder_and_upcast_attn=False,
|
|
|
|
| 162 |
**kwargs,
|
| 163 |
):
|
| 164 |
self.vocab_size = vocab_size
|
|
|
|
| 186 |
self.bos_token_id = bos_token_id
|
| 187 |
self.eos_token_id = eos_token_id
|
| 188 |
|
| 189 |
+
super().__init__(bos_token_id=bos_token_id, eos_token_id=eos_token_id, **kwargs)
|
modeling_gpt2mimo.py
CHANGED
|
@@ -4,6 +4,7 @@ from typing import Optional, Tuple, Union
|
|
| 4 |
|
| 5 |
import torch
|
| 6 |
import torch.utils.checkpoint
|
|
|
|
| 7 |
from torch import nn
|
| 8 |
from torch.nn import CrossEntropyLoss
|
| 9 |
|
|
@@ -15,10 +16,10 @@ from transformers.modeling_outputs import (
|
|
| 15 |
)
|
| 16 |
from transformers.modeling_utils import PreTrainedModel
|
| 17 |
from transformers.pytorch_utils import Conv1D, find_pruneable_heads_and_indices, prune_conv1d_layer
|
| 18 |
-
from transformers.utils import logging
|
| 19 |
from transformers.utils.model_parallel_utils import assert_device_map, get_device_map
|
| 20 |
from .configuration_gpt2mimo import GPT2MIMOConfig
|
| 21 |
-
|
| 22 |
|
| 23 |
|
| 24 |
|
|
@@ -249,6 +250,114 @@ class GPT2Attention(nn.Module):
|
|
| 249 |
|
| 250 |
|
| 251 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 252 |
class GPT2MLP(nn.Module):
|
| 253 |
def __init__(self, intermediate_size, config):
|
| 254 |
super().__init__()
|
|
@@ -266,7 +375,7 @@ class GPT2MLP(nn.Module):
|
|
| 266 |
return hidden_states
|
| 267 |
|
| 268 |
|
| 269 |
-
GPT2_ATTENTION_CLASSES = {"eager": GPT2Attention}
|
| 270 |
|
| 271 |
|
| 272 |
class GPT2Block(nn.Module):
|
|
@@ -533,7 +642,12 @@ class GPT2MIMOModel(GPT2PreTrainedModel):
|
|
| 533 |
if self._attn_implementation == "flash_attention_2":
|
| 534 |
attention_mask = attention_mask if (attention_mask is not None and 0 in attention_mask) else None
|
| 535 |
elif _use_sdpa:
|
| 536 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 537 |
else:
|
| 538 |
if attention_mask is not None:
|
| 539 |
# We create a 3D attention mask from a 2D tensor mask.
|
|
@@ -559,7 +673,9 @@ class GPT2MIMOModel(GPT2PreTrainedModel):
|
|
| 559 |
if encoder_attention_mask is None:
|
| 560 |
encoder_attention_mask = torch.ones(encoder_hidden_shape, device=device)
|
| 561 |
if _use_sdpa:
|
| 562 |
-
|
|
|
|
|
|
|
| 563 |
elif not self._attn_implementation == "flash_attention_2":
|
| 564 |
encoder_attention_mask = self.invert_attention_mask(encoder_attention_mask)
|
| 565 |
else:
|
|
|
|
| 4 |
|
| 5 |
import torch
|
| 6 |
import torch.utils.checkpoint
|
| 7 |
+
from packaging import version
|
| 8 |
from torch import nn
|
| 9 |
from torch.nn import CrossEntropyLoss
|
| 10 |
|
|
|
|
| 16 |
)
|
| 17 |
from transformers.modeling_utils import PreTrainedModel
|
| 18 |
from transformers.pytorch_utils import Conv1D, find_pruneable_heads_and_indices, prune_conv1d_layer
|
| 19 |
+
from transformers.utils import logging, is_flash_attn_2_available, get_torch_version
|
| 20 |
from transformers.utils.model_parallel_utils import assert_device_map, get_device_map
|
| 21 |
from .configuration_gpt2mimo import GPT2MIMOConfig
|
| 22 |
+
from transformers.modeling_attn_mask_utils import _prepare_4d_attention_mask_for_sdpa, _prepare_4d_causal_attention_mask_for_sdpa
|
| 23 |
|
| 24 |
|
| 25 |
|
|
|
|
| 250 |
|
| 251 |
|
| 252 |
|
| 253 |
+
class GPT2SdpaAttention(GPT2Attention):
|
| 254 |
+
"""
|
| 255 |
+
GPT2 attention module using torch.nn.functional.scaled_dot_product_attention. This module inherits from
|
| 256 |
+
`GPT2Attention` as the weights of the module stay untouched. The only changes are on the forward pass
|
| 257 |
+
to adapt to the SDPA API.
|
| 258 |
+
"""
|
| 259 |
+
|
| 260 |
+
def __init__(self, *args, **kwargs):
|
| 261 |
+
super().__init__(*args, **kwargs)
|
| 262 |
+
|
| 263 |
+
# Idea adapted from transformers.models.bert.modeling_bert.BertSdpaSelfAttention.__init__
|
| 264 |
+
# SDPA with memory-efficient backend is broken in torch==2.1.2 when using non-contiguous inputs and a custom
|
| 265 |
+
# attn_mask, so we need to call `.contiguous()`. This was fixed in torch==2.2.0.
|
| 266 |
+
# Reference: https://github.com/pytorch/pytorch/issues/112577
|
| 267 |
+
self.require_contiguous_qkv = version.parse(get_torch_version()) < version.parse("2.2.0")
|
| 268 |
+
|
| 269 |
+
def forward(
|
| 270 |
+
self,
|
| 271 |
+
hidden_states: Optional[Tuple[torch.FloatTensor]],
|
| 272 |
+
layer_past: Optional[Tuple[torch.Tensor]] = None,
|
| 273 |
+
attention_mask: Optional[torch.FloatTensor] = None,
|
| 274 |
+
head_mask: Optional[torch.FloatTensor] = None,
|
| 275 |
+
encoder_hidden_states: Optional[torch.Tensor] = None,
|
| 276 |
+
encoder_attention_mask: Optional[torch.FloatTensor] = None,
|
| 277 |
+
use_cache: Optional[bool] = False,
|
| 278 |
+
output_attentions: Optional[bool] = False,
|
| 279 |
+
) -> Tuple[Union[torch.Tensor, Tuple[torch.Tensor]], ...]:
|
| 280 |
+
if output_attentions or head_mask is not None:
|
| 281 |
+
logger.warning_once(
|
| 282 |
+
"`GPT2SdpaAttention` is used but `torch.nn.functional.scaled_dot_product_attention` does not support "
|
| 283 |
+
"`output_attentions=True` or `head_mask`. Falling back to the manual attention implementation, but "
|
| 284 |
+
"specifying the manual implementation will be required from Transformers version v5.0.0 onwards. "
|
| 285 |
+
'This warning can be removed using the argument `attn_implementation="eager"` when loading the model.'
|
| 286 |
+
)
|
| 287 |
+
return super().forward(
|
| 288 |
+
hidden_states=hidden_states,
|
| 289 |
+
layer_past=layer_past,
|
| 290 |
+
attention_mask=attention_mask,
|
| 291 |
+
head_mask=head_mask,
|
| 292 |
+
encoder_hidden_states=encoder_hidden_states,
|
| 293 |
+
encoder_attention_mask=encoder_attention_mask,
|
| 294 |
+
use_cache=use_cache,
|
| 295 |
+
output_attentions=output_attentions,
|
| 296 |
+
)
|
| 297 |
+
|
| 298 |
+
bsz, q_len, _ = hidden_states.size()
|
| 299 |
+
|
| 300 |
+
# Initial attention projections
|
| 301 |
+
is_cross_attention = encoder_hidden_states is not None
|
| 302 |
+
if is_cross_attention:
|
| 303 |
+
if not hasattr(self, "q_attn"):
|
| 304 |
+
raise ValueError(
|
| 305 |
+
"If class is used as cross attention, the weights `q_attn` have to be defined. "
|
| 306 |
+
"Please make sure to instantiate class with `GPT2SdpaAttention(..., is_cross_attention=True)`."
|
| 307 |
+
)
|
| 308 |
+
|
| 309 |
+
query = self.q_attn(hidden_states)
|
| 310 |
+
key, value = self.c_attn(encoder_hidden_states).split(self.split_size, dim=2)
|
| 311 |
+
attention_mask = encoder_attention_mask
|
| 312 |
+
else:
|
| 313 |
+
query, key, value = self.c_attn(hidden_states).split(self.split_size, dim=2)
|
| 314 |
+
|
| 315 |
+
query = self._split_heads(query, self.num_heads, self.head_dim)
|
| 316 |
+
key = self._split_heads(key, self.num_heads, self.head_dim)
|
| 317 |
+
value = self._split_heads(value, self.num_heads, self.head_dim)
|
| 318 |
+
|
| 319 |
+
# Optional kv caching
|
| 320 |
+
if layer_past is not None:
|
| 321 |
+
past_key = layer_past[0]
|
| 322 |
+
past_value = layer_past[1]
|
| 323 |
+
key = torch.cat((past_key, key), dim=-2)
|
| 324 |
+
value = torch.cat((past_value, value), dim=-2)
|
| 325 |
+
|
| 326 |
+
present = None
|
| 327 |
+
if use_cache is True:
|
| 328 |
+
present = (key, value)
|
| 329 |
+
|
| 330 |
+
# Avoid torch==2.1.2 specific bug for the memory-efficient backend in SDPA
|
| 331 |
+
if self.require_contiguous_qkv and query.device.type == "cuda" and attention_mask is not None:
|
| 332 |
+
query = query.contiguous()
|
| 333 |
+
key = key.contiguous()
|
| 334 |
+
value = value.contiguous()
|
| 335 |
+
|
| 336 |
+
# We dispatch to SDPA's Flash Attention or Efficient kernels via this `is_causal` if statement instead of an inline conditional assignment
|
| 337 |
+
# in SDPA to support both torch.compile's dynamic shapes and full graph options. An inline conditional prevents dynamic shapes from compiling.
|
| 338 |
+
is_causal = True if attention_mask is None and q_len > 1 and not is_cross_attention else False
|
| 339 |
+
|
| 340 |
+
attn_output = torch.nn.functional.scaled_dot_product_attention(
|
| 341 |
+
query,
|
| 342 |
+
key,
|
| 343 |
+
value,
|
| 344 |
+
attn_mask=attention_mask,
|
| 345 |
+
dropout_p=self.attn_dropout.p if self.training else 0.0,
|
| 346 |
+
is_causal=is_causal,
|
| 347 |
+
)
|
| 348 |
+
|
| 349 |
+
# Reshape outputs
|
| 350 |
+
attn_output = attn_output.transpose(1, 2).contiguous()
|
| 351 |
+
attn_output = attn_output.view(bsz, q_len, self.embed_dim)
|
| 352 |
+
|
| 353 |
+
# Final projection
|
| 354 |
+
attn_output = self.c_proj(attn_output)
|
| 355 |
+
attn_output = self.resid_dropout(attn_output)
|
| 356 |
+
|
| 357 |
+
return attn_output, present, None
|
| 358 |
+
|
| 359 |
+
|
| 360 |
+
|
| 361 |
class GPT2MLP(nn.Module):
|
| 362 |
def __init__(self, intermediate_size, config):
|
| 363 |
super().__init__()
|
|
|
|
| 375 |
return hidden_states
|
| 376 |
|
| 377 |
|
| 378 |
+
GPT2_ATTENTION_CLASSES = {"eager": GPT2Attention, "sdpa": GPT2SdpaAttention}
|
| 379 |
|
| 380 |
|
| 381 |
class GPT2Block(nn.Module):
|
|
|
|
| 642 |
if self._attn_implementation == "flash_attention_2":
|
| 643 |
attention_mask = attention_mask if (attention_mask is not None and 0 in attention_mask) else None
|
| 644 |
elif _use_sdpa:
|
| 645 |
+
attention_mask = _prepare_4d_causal_attention_mask_for_sdpa(
|
| 646 |
+
attention_mask=attention_mask,
|
| 647 |
+
input_shape=(batch_size, input_shape[-1]),
|
| 648 |
+
inputs_embeds=inputs_embeds,
|
| 649 |
+
past_key_values_length=past_length,
|
| 650 |
+
)
|
| 651 |
else:
|
| 652 |
if attention_mask is not None:
|
| 653 |
# We create a 3D attention mask from a 2D tensor mask.
|
|
|
|
| 673 |
if encoder_attention_mask is None:
|
| 674 |
encoder_attention_mask = torch.ones(encoder_hidden_shape, device=device)
|
| 675 |
if _use_sdpa:
|
| 676 |
+
encoder_attention_mask = _prepare_4d_attention_mask_for_sdpa(
|
| 677 |
+
mask=encoder_attention_mask, dtype=inputs_embeds.dtype, tgt_len=input_shape[-1]
|
| 678 |
+
)
|
| 679 |
elif not self._attn_implementation == "flash_attention_2":
|
| 680 |
encoder_attention_mask = self.invert_attention_mask(encoder_attention_mask)
|
| 681 |
else:
|