LieUr commited on
Commit
3c2f18f
·
verified ·
1 Parent(s): 901575d

Upload folder using huggingface_hub

Browse files
config.json ADDED
@@ -0,0 +1,39 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "activation_function": "gelu_new",
3
+ "architectures": [
4
+ "SVD_GPT2LMHeadModel"
5
+ ],
6
+ "attn_pdrop": 0.1,
7
+ "bos_token_id": 50256,
8
+ "embd_pdrop": 0.1,
9
+ "eos_token_id": 50256,
10
+ "initializer_range": 0.02,
11
+ "layer_norm_epsilon": 1e-05,
12
+ "model_type": "svd_gpt2",
13
+ "n_ctx": 1024,
14
+ "n_embd": 768,
15
+ "n_head": 12,
16
+ "n_inner": null,
17
+ "n_layer": 12,
18
+ "n_positions": 1024,
19
+ "ratio": 1.0,
20
+ "reorder_and_upcast_attn": false,
21
+ "resid_pdrop": 0.1,
22
+ "scale_attn_by_inverse_layer_idx": false,
23
+ "scale_attn_weights": true,
24
+ "summary_activation": null,
25
+ "summary_first_dropout": 0.1,
26
+ "summary_proj_to_labels": true,
27
+ "summary_type": "cls_index",
28
+ "summary_use_proj": true,
29
+ "task_specific_params": {
30
+ "text-generation": {
31
+ "do_sample": true,
32
+ "max_length": 50
33
+ }
34
+ },
35
+ "torch_dtype": "float32",
36
+ "transformers_version": "4.52.4",
37
+ "use_cache": true,
38
+ "vocab_size": 50257
39
+ }
configuration_svd_gpt2.py ADDED
@@ -0,0 +1,62 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from transformers import GPT2Config
2
+
3
class SVDGPT2Config(GPT2Config):
    """Configuration class for SVD-compressed GPT-2.

    Identical to ``GPT2Config`` in every respect except for the additional
    ``ratio`` field, which controls the parameter budget of the low-rank
    (SVD-factorized) projection layers. ``ratio=1.0`` keeps roughly the dense
    parameter count; smaller values shrink the factor rank proportionally.
    """

    model_type = "svd_gpt2"

    def __init__(
        self,
        vocab_size=50257,
        n_positions=1024,
        n_embd=768,
        n_layer=12,
        n_head=12,
        n_inner=None,
        activation_function="gelu_new",
        resid_pdrop=0.1,
        embd_pdrop=0.1,
        attn_pdrop=0.1,
        layer_norm_epsilon=1e-5,
        initializer_range=0.02,
        summary_type="cls_index",
        summary_use_proj=True,
        summary_activation=None,
        summary_proj_to_labels=True,
        summary_first_dropout=0.1,
        scale_attn_weights=True,
        use_cache=True,
        bos_token_id=50256,
        eos_token_id=50256,
        scale_attn_by_inverse_layer_idx=False,
        reorder_and_upcast_attn=False,
        ratio=1.0,
        **kwargs,
    ):
        # Gather the vanilla GPT-2 options and forward them unchanged to the
        # parent config; duplicate keys supplied via **kwargs still raise a
        # TypeError, exactly as an explicit keyword call would.
        gpt2_options = dict(
            vocab_size=vocab_size,
            n_positions=n_positions,
            n_embd=n_embd,
            n_layer=n_layer,
            n_head=n_head,
            n_inner=n_inner,
            activation_function=activation_function,
            resid_pdrop=resid_pdrop,
            embd_pdrop=embd_pdrop,
            attn_pdrop=attn_pdrop,
            layer_norm_epsilon=layer_norm_epsilon,
            initializer_range=initializer_range,
            summary_type=summary_type,
            summary_use_proj=summary_use_proj,
            summary_activation=summary_activation,
            summary_proj_to_labels=summary_proj_to_labels,
            summary_first_dropout=summary_first_dropout,
            scale_attn_weights=scale_attn_weights,
            use_cache=use_cache,
            bos_token_id=bos_token_id,
            eos_token_id=eos_token_id,
            scale_attn_by_inverse_layer_idx=scale_attn_by_inverse_layer_idx,
            reorder_and_upcast_attn=reorder_and_upcast_attn,
        )
        super().__init__(**gpt2_options, **kwargs)

        ## SVD-specific parameters
        self.ratio = ratio
generation_config.json ADDED
@@ -0,0 +1,6 @@
 
 
 
 
 
 
 
1
+ {
2
+ "_from_model_config": true,
3
+ "bos_token_id": 50256,
4
+ "eos_token_id": 50256,
5
+ "transformers_version": "4.52.4"
6
+ }
merges.txt ADDED
The diff for this file is too large to render. See raw diff
 
model.safetensors ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:f0c99147a9d386d814567417baa31e759678a814e2dacba489221c19ea4c5a47
3
+ size 497641368
modeling_svd_gpt2.py ADDED
@@ -0,0 +1,1783 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import math
2
+ import os
3
+ import warnings
4
+ from dataclasses import dataclass
5
+ from typing import Optional, Tuple, Union
6
+
7
+ import torch
8
+ import torch.utils.checkpoint
9
+ from packaging import version
10
+ from torch import nn
11
+ from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss
12
+
13
+ from transformers.activations import ACT2FN
14
+ from transformers.generation import GenerationMixin
15
+ from transformers.modeling_attn_mask_utils import _prepare_4d_attention_mask_for_sdpa, _prepare_4d_causal_attention_mask_for_sdpa
16
+ from transformers.modeling_outputs import (
17
+ BaseModelOutputWithPastAndCrossAttentions,
18
+ CausalLMOutputWithCrossAttentions,
19
+ QuestionAnsweringModelOutput,
20
+ SequenceClassifierOutputWithPast,
21
+ TokenClassifierOutput,
22
+ )
23
+ from transformers.modeling_utils import PreTrainedModel, SequenceSummary
24
+ from transformers.pytorch_utils import Conv1D, find_pruneable_heads_and_indices, prune_conv1d_layer
25
+ from transformers.utils import (
26
+ ModelOutput,
27
+ add_code_sample_docstrings,
28
+ add_start_docstrings,
29
+ add_start_docstrings_to_model_forward,
30
+ get_torch_version,
31
+ is_flash_attn_2_available,
32
+ is_flash_attn_greater_or_equal_2_10,
33
+ logging,
34
+ replace_return_docstrings,
35
+ )
36
+ from transformers.utils.model_parallel_utils import assert_device_map, get_device_map
37
# Support both package-relative execution (loaded via `trust_remote_code` /
# as part of a package) and direct script-style execution of this module.
# Catch only ImportError: a bare `except:` would also swallow SyntaxError in
# the sibling module, KeyboardInterrupt, etc., masking real failures.
try:
    from .configuration_svd_gpt2 import SVDGPT2Config
except ImportError:
    from configuration_svd_gpt2 import SVDGPT2Config
41
+
42
+ if is_flash_attn_2_available():
43
+ from transformers.modeling_flash_attention_utils import _flash_attention_forward
44
+
45
+
46
+ logger = logging.get_logger(__name__)
47
+
48
+ _CHECKPOINT_FOR_DOC = "svd_gpt2"
49
+ _CONFIG_FOR_DOC = "SVDGPT2Config"
50
+
51
class SVD_Linear(nn.Module):
    """Low-rank (SVD-style) drop-in replacement for a dense linear layer.

    A dense ``in_features x out_features`` weight is replaced by two factors:
    a down-projection to ``rank`` dimensions followed by an up-projection,
    so the parameter count is roughly ``ratio`` times that of the dense layer.
    """

    def __init__(self, in_features: int, out_features: int, bias: bool = False, ratio: float = 1.0):
        super().__init__()

        self.in_features = in_features
        self.out_features = out_features
        self.ratio = ratio

        # Pick the rank so that rank*(in+out) ≈ ratio * in*out, i.e. the two
        # factors together hold ~`ratio` of the dense layer's parameters.
        budget_rank = int((ratio * in_features * out_features) / (in_features + out_features))
        # At least 1, and never more than the smaller dimension (a valid SVD rank).
        self.rank = min(max(1, budget_rank), in_features, out_features)

        # SVD down projection (no bias) followed by SVD up projection
        # (carries the layer's bias, if any).
        self.weight_v = nn.Linear(in_features, self.rank, bias=False)
        self.weight_u = nn.Linear(self.rank, out_features, bias=bias)

    def forward(self, input):
        """Apply the factorized projection: up(down(input))."""
        compressed = self.weight_v(input)
        return self.weight_u(compressed)
66
+
67
class SVD_GPT2Attention(nn.Module):
    """GPT-2 multi-head self-attention with the dense q/k/v/output projections
    replaced by low-rank `SVD_Linear` factors (rank controlled by `config.ratio`).

    Unlike upstream GPT-2 (single fused `c_attn` Conv1D), this module keeps
    separate `q_proj`/`k_proj`/`v_proj` plus an output `c_proj`, each factorized.

    NOTE(review): no `q_attn` is ever defined in this class, so the
    cross-attention path (`encoder_hidden_states is not None`) always raises
    in `forward` — effectively only self-attention is supported.
    """

    def __init__(self, config, is_cross_attention=False, layer_idx=None):
        super().__init__()
        self.config = config
        max_positions = config.max_position_embeddings
        # Lower-triangular causal mask, shaped (1, 1, max_pos, max_pos) so it
        # broadcasts over (batch, head); non-persistent so it is not saved in
        # checkpoints.
        self.register_buffer(
            "bias",
            torch.tril(torch.ones((max_positions, max_positions), dtype=torch.bool)).view(
                1, 1, max_positions, max_positions
            ),
            persistent=False,
        )
        self.register_buffer("masked_bias", torch.tensor(-1e4), persistent=False)

        self.embed_dim = config.hidden_size
        self.num_heads = config.num_attention_heads
        self.head_dim = self.embed_dim // self.num_heads
        self.split_size = self.embed_dim
        if self.head_dim * self.num_heads != self.embed_dim:
            raise ValueError(
                f"`embed_dim` must be divisible by num_heads (got `embed_dim`: {self.embed_dim} and `num_heads`:"
                f" {self.num_heads})."
            )

        self.scale_attn_weights = config.scale_attn_weights
        self.is_cross_attention = is_cross_attention

        # Layer-wise attention scaling, reordering, and upcasting
        self.scale_attn_by_inverse_layer_idx = config.scale_attn_by_inverse_layer_idx
        self.layer_idx = layer_idx
        self.reorder_and_upcast_attn = config.reorder_and_upcast_attn

        # ============== Modifications for SVD
        # Each projection is a rank-limited factorization; `ratio` scales the
        # parameter budget relative to the dense equivalent.
        self.ratio = ratio = config.ratio
        self.q_proj = SVD_Linear(self.embed_dim, self.embed_dim, bias=True, ratio=ratio)
        self.k_proj = SVD_Linear(self.embed_dim, self.embed_dim, bias=True, ratio=ratio)
        self.v_proj = SVD_Linear(self.embed_dim, self.embed_dim, bias=True, ratio=ratio)
        self.c_proj = SVD_Linear(self.embed_dim, self.embed_dim, bias=True, ratio=ratio)
        # ===============

        self.attn_dropout = nn.Dropout(config.attn_pdrop)
        self.resid_dropout = nn.Dropout(config.resid_pdrop)
        self.is_causal = True

        self.pruned_heads = set()

    def prune_heads(self, heads):
        # Head pruning would have to operate on the factorized weights, which
        # is not implemented for SVD projections.
        raise ValueError("Not supported currently")

    def _attn(self, query, key, value, attention_mask=None, head_mask=None):
        """Eager attention: softmax(QK^T / sqrt(d) + mask) V.

        query: (batch, head, q_len, head_dim); key/value: (batch, head, k_len, head_dim).
        Returns (attn_output, attn_weights).
        """
        attn_weights = torch.matmul(query, key.transpose(-1, -2))

        if self.scale_attn_weights:
            attn_weights = attn_weights / torch.full(
                [], value.size(-1) ** 0.5, dtype=attn_weights.dtype, device=attn_weights.device
            )

        # Layer-wise attention scaling
        if self.scale_attn_by_inverse_layer_idx:
            attn_weights = attn_weights / float(self.layer_idx + 1)

        if not self.is_cross_attention:
            # if only "normal" attention layer implements causal mask
            query_length, key_length = query.size(-2), key.size(-2)
            causal_mask = self.bias[:, :, key_length - query_length : key_length, :key_length]
            mask_value = torch.finfo(attn_weights.dtype).min
            # Need to be a tensor, otherwise we get error: `RuntimeError: expected scalar type float but found double`.
            # Need to be on the same device, otherwise `RuntimeError: ..., x and y to be on the same device`
            mask_value = torch.full([], mask_value, dtype=attn_weights.dtype, device=attn_weights.device)
            attn_weights = torch.where(causal_mask, attn_weights.to(attn_weights.dtype), mask_value)

        if attention_mask is not None:
            # Apply the attention mask
            attn_weights = attn_weights + attention_mask

        attn_weights = nn.functional.softmax(attn_weights, dim=-1)

        # Downcast (if necessary) back to V's dtype (if in mixed-precision) -- No-Op otherwise
        attn_weights = attn_weights.type(value.dtype)
        attn_weights = self.attn_dropout(attn_weights)

        # Mask heads if we want to
        if head_mask is not None:
            attn_weights = attn_weights * head_mask

        attn_output = torch.matmul(attn_weights, value)

        return attn_output, attn_weights

    def _upcast_and_reordered_attn(self, query, key, value, attention_mask=None, head_mask=None):
        """Numerically-safer attention variant: computes QK^T in float32 with
        autocast disabled, folding all scaling into `baddbmm`'s alpha.
        Used when `config.reorder_and_upcast_attn` is set.
        """
        # Use `torch.baddbmm` (a bit more efficient w/ alpha param for scaling -- from Megatron-LM)
        bsz, num_heads, q_seq_len, dk = query.size()
        _, _, k_seq_len, _ = key.size()

        # Preallocate attn_weights for `baddbmm`
        attn_weights = torch.empty(bsz * num_heads, q_seq_len, k_seq_len, dtype=torch.float32, device=query.device)

        # Compute Scale Factor
        scale_factor = 1.0
        if self.scale_attn_weights:
            scale_factor /= float(value.size(-1)) ** 0.5

        if self.scale_attn_by_inverse_layer_idx:
            scale_factor /= float(self.layer_idx + 1)

        # Upcast (turn off autocast) and reorder (Scale K by 1 / root(dk))
        with torch.amp.autocast(query.device.type, enabled=False):
            q, k = query.reshape(-1, q_seq_len, dk), key.transpose(-1, -2).reshape(-1, dk, k_seq_len)
            attn_weights = torch.baddbmm(attn_weights, q.float(), k.float(), beta=0, alpha=scale_factor)
            attn_weights = attn_weights.reshape(bsz, num_heads, q_seq_len, k_seq_len)

        if not self.is_cross_attention:
            # if only "normal" attention layer implements causal mask
            query_length, key_length = query.size(-2), key.size(-2)
            causal_mask = self.bias[:, :, key_length - query_length : key_length, :key_length]
            mask_value = torch.finfo(attn_weights.dtype).min
            # Need to be a tensor, otherwise we get error: `RuntimeError: expected scalar type float but found double`.
            # Need to be on the same device, otherwise `RuntimeError: ..., x and y to be on the same device`
            mask_value = torch.tensor(mask_value, dtype=attn_weights.dtype).to(attn_weights.device)
            attn_weights = torch.where(causal_mask, attn_weights, mask_value)

        if attention_mask is not None:
            # Apply the attention mask
            attn_weights = attn_weights + attention_mask

        attn_weights = nn.functional.softmax(attn_weights, dim=-1)

        # Downcast (if necessary) back to V's dtype (if in mixed-precision) -- No-Op if otherwise
        if attn_weights.dtype != torch.float32:
            raise RuntimeError("Error with upcasting, attn_weights does not have dtype torch.float32")
        attn_weights = attn_weights.type(value.dtype)
        attn_weights = self.attn_dropout(attn_weights)

        # Mask heads if we want to
        if head_mask is not None:
            attn_weights = attn_weights * head_mask

        attn_output = torch.matmul(attn_weights, value)

        return attn_output, attn_weights

    def _split_heads(self, tensor, num_heads, attn_head_size):
        """
        Splits hidden_size dim into attn_head_size and num_heads
        """
        new_shape = tensor.size()[:-1] + (num_heads, attn_head_size)
        tensor = tensor.view(new_shape)
        return tensor.permute(0, 2, 1, 3)  # (batch, head, seq_length, head_features)

    def _merge_heads(self, tensor, num_heads, attn_head_size):
        """
        Merges attn_head_size dim and num_attn_heads dim into hidden_size
        """
        tensor = tensor.permute(0, 2, 1, 3).contiguous()
        new_shape = tensor.size()[:-2] + (num_heads * attn_head_size,)
        return tensor.view(new_shape)

    def forward(
        self,
        hidden_states: Optional[Tuple[torch.FloatTensor]],
        layer_past: Optional[Tuple[torch.Tensor]] = None,
        attention_mask: Optional[torch.FloatTensor] = None,
        head_mask: Optional[torch.FloatTensor] = None,
        encoder_hidden_states: Optional[torch.Tensor] = None,
        encoder_attention_mask: Optional[torch.FloatTensor] = None,
        use_cache: Optional[bool] = False,
        output_attentions: Optional[bool] = False,
    ) -> Tuple[Union[torch.Tensor, Tuple[torch.Tensor]], ...]:
        """Project, attend, and re-project.

        Returns `(attn_output, present)` plus `attn_weights` when
        `output_attentions=True`; `present` is the `(key, value)` cache tuple
        when `use_cache=True`, else None.
        """
        if encoder_hidden_states is not None:
            # `q_attn` is never created in this class, so this always raises
            # (cross-attention unsupported with the SVD projections).
            if not hasattr(self, "q_attn"):
                raise ValueError(
                    "If class is used as cross attention, the weights `q_attn` have to be defined. "
                    "Please make sure to instantiate class with `GPT2Attention(..., is_cross_attention=True)`."
                )
            attention_mask = encoder_attention_mask

        query = self.q_proj(hidden_states)
        key = self.k_proj(hidden_states)
        value = self.v_proj(hidden_states)

        query = self._split_heads(query, self.num_heads, self.head_dim)
        key = self._split_heads(key, self.num_heads, self.head_dim)
        value = self._split_heads(value, self.num_heads, self.head_dim)

        # Prepend cached keys/values from previous decoding steps along the
        # sequence dimension.
        if layer_past is not None:
            past_key, past_value = layer_past
            key = torch.cat((past_key, key), dim=-2)
            value = torch.cat((past_value, value), dim=-2)

        if use_cache is True:
            present = (key, value)
        else:
            present = None

        if self.reorder_and_upcast_attn:
            attn_output, attn_weights = self._upcast_and_reordered_attn(query, key, value, attention_mask, head_mask)
        else:
            attn_output, attn_weights = self._attn(query, key, value, attention_mask, head_mask)

        attn_output = self._merge_heads(attn_output, self.num_heads, self.head_dim)
        attn_output = self.c_proj(attn_output)
        attn_output = self.resid_dropout(attn_output)

        outputs = (attn_output, present)
        if output_attentions:
            outputs += (attn_weights,)

        return outputs  # a, present, (attentions)
276
+
277
class SVD_GPT2FlashAttention2(SVD_GPT2Attention):
    """
    GPT2 flash attention module. This module inherits from `GPT2Attention` as the weights of the module stays
    untouched. The only required change would be on the forward pass where it needs to correctly call the public API of
    flash attention and deal with padding tokens in case the input contains any of them.
    """

    # Copied from transformers.models.llama.modeling_llama.LlamaFlashAttention2.__init__
    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)

        # TODO: Should be removed once Flash Attention for RoCm is bumped to 2.1.
        # flash_attn<2.1 generates top-left aligned causal mask, while what is needed here is bottom-right alignement, that was made default for flash_attn>=2.1. This attribute is used to handle this difference. Reference: https://github.com/Dao-AILab/flash-attention/releases/tag/v2.1.0.
        # Beware that with flash_attn<2.1, using q_seqlen != k_seqlen (except for the case q_seqlen == 1) produces a wrong mask (top-left).
        self._flash_attn_uses_top_left_mask = not is_flash_attn_greater_or_equal_2_10()

    def forward(
        self,
        hidden_states: Optional[Tuple[torch.FloatTensor]],
        layer_past: Optional[Tuple[torch.Tensor]] = None,
        attention_mask: Optional[torch.FloatTensor] = None,
        head_mask: Optional[torch.FloatTensor] = None,
        encoder_hidden_states: Optional[torch.Tensor] = None,
        encoder_attention_mask: Optional[torch.FloatTensor] = None,
        use_cache: Optional[bool] = False,
        output_attentions: Optional[bool] = False,
    ) -> Tuple[Union[torch.Tensor, Tuple[torch.Tensor]], ...]:
        """Same contract as `SVD_GPT2Attention.forward`, dispatching the score
        computation to flash attention.

        NOTE(review): `head_mask` is accepted but never applied here, and with
        `output_attentions=True` the third element is the reshaped attention
        *output*, not attention probabilities (flash attention does not
        materialize the weight matrix).
        """
        bsz, _, _ = hidden_states.size()
        if encoder_hidden_states is not None:
            attention_mask = encoder_attention_mask

        query = self.q_proj(hidden_states)
        key = self.k_proj(hidden_states)
        value = self.v_proj(hidden_states)

        query = self._split_heads(query, self.num_heads, self.head_dim)
        key = self._split_heads(key, self.num_heads, self.head_dim)
        value = self._split_heads(value, self.num_heads, self.head_dim)

        # Prepend cached keys/values along the sequence dimension.
        if layer_past is not None:
            past_key = layer_past[0]
            past_value = layer_past[1]
            key = torch.cat((past_key, key), dim=-2)
            value = torch.cat((past_value, value), dim=-2)

        present = None
        if use_cache is True:
            present = (key, value)

        query_length = query.shape[2]
        tgt_len = key.shape[2]

        # Flash attention requires the input to have the shape
        # batch_size x seq_length x head_dim x hidden_dim
        query = query.transpose(1, 2).view(bsz, query_length, self.num_heads, self.head_dim)
        key = key.transpose(1, 2).view(bsz, tgt_len, self.num_heads, self.head_dim)
        value = value.transpose(1, 2).view(bsz, tgt_len, self.num_heads, self.head_dim)

        # Dropout inside the kernel is only active during training.
        attn_dropout = self.attn_dropout.p if self.training else 0.0

        # In PEFT, usually we cast the layer norms in float32 for training stability reasons
        # therefore the input hidden states gets silently casted in float32. Hence, we need
        # cast them back in the correct dtype just to be sure everything works as expected.
        # This might slowdown training & inference so it is recommended to not cast the LayerNorms
        # in fp32. (LlamaRMSNorm handles it correctly)

        if query.dtype == torch.float32:
            if torch.is_autocast_enabled():
                target_dtype = torch.get_autocast_gpu_dtype()
            # Handle the case where the model is quantized
            elif hasattr(self.config, "_pre_quantization_dtype"):
                target_dtype = self.config._pre_quantization_dtype
            else:
                # Fall back to the dtype of the output projection's up-factor weight.
                target_dtype = self.c_proj.weight_u.weight.dtype

            logger.warning_once(
                f"The input hidden states seems to be silently casted in float32, this might be related to"
                f" the fact you have upcasted embedding or layer norm layers in float32. We will cast back the input in"
                f" {target_dtype}."
            )

            query = query.to(target_dtype)
            key = key.to(target_dtype)
            value = value.to(target_dtype)

        attn_output = _flash_attention_forward(
            query,
            key,
            value,
            attention_mask,
            query_length,
            dropout=attn_dropout,
            is_causal=self.is_causal,
            use_top_left_mask=self._flash_attn_uses_top_left_mask,
        )

        # Collapse (batch, seq, heads, head_dim) back to (batch, seq, hidden).
        attn_weights_reshaped = attn_output.reshape(bsz, query_length, self.num_heads * self.head_dim)
        attn_output = self.c_proj(attn_weights_reshaped)
        attn_output = self.resid_dropout(attn_output)

        outputs = (attn_output, present)
        if output_attentions:
            outputs += (attn_weights_reshaped,)

        return outputs
382
+
383
class SVD_GPT2SdpaAttention(SVD_GPT2Attention):
    """
    GPT2 attention module using torch.nn.functional.scaled_dot_product_attention. This module inherits from
    `GPT2Attention` as the weights of the module stays untouched. The only changes are on the forward pass
    to adapt to the SDPA API.
    """

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)

        # Idea adapted from transformers.models.bert.modeling_bert.BertSdpaSelfAttention.__init__
        # SDPA with memory-efficient backend is broken in torch==2.1.2 when using non-contiguous inputs and a custom
        # attn_mask, so we need to call `.contiguous()`. This was fixed in torch==2.2.0.
        # Reference: https://github.com/pytorch/pytorch/issues/112577
        self.require_contiguous_qkv = version.parse(get_torch_version()) < version.parse("2.2.0")

    def forward(
        self,
        hidden_states: Optional[Tuple[torch.FloatTensor]],
        layer_past: Optional[Tuple[torch.Tensor]] = None,
        attention_mask: Optional[torch.FloatTensor] = None,
        head_mask: Optional[torch.FloatTensor] = None,
        encoder_hidden_states: Optional[torch.Tensor] = None,
        encoder_attention_mask: Optional[torch.FloatTensor] = None,
        use_cache: Optional[bool] = False,
        output_attentions: Optional[bool] = False,
    ) -> Tuple[Union[torch.Tensor, Tuple[torch.Tensor]], ...]:
        """SDPA-based attention; falls back to the eager parent implementation
        when `output_attentions` or `head_mask` is requested (SDPA supports
        neither). Returns `(attn_output, present, None)` — the trailing None
        stands in for attention weights.
        """
        if output_attentions or head_mask is not None:
            logger.warning_once(
                "`GPT2SdpaAttention` is used but `torch.nn.functional.scaled_dot_product_attention` does not support "
                "`output_attentions=True` or `head_mask`. Falling back to the manual attention implementation, but "
                "specifying the manual implementation will be required from Transformers version v5.0.0 onwards. "
                'This warning can be removed using the argument `attn_implementation="eager"` when loading the model.'
            )
            return super().forward(
                hidden_states=hidden_states,
                layer_past=layer_past,
                attention_mask=attention_mask,
                head_mask=head_mask,
                encoder_hidden_states=encoder_hidden_states,
                encoder_attention_mask=encoder_attention_mask,
                use_cache=use_cache,
                output_attentions=output_attentions,
            )

        bsz, q_len, _ = hidden_states.size()

        # Initial attention projections
        is_cross_attention = encoder_hidden_states is not None
        if is_cross_attention:
            attention_mask = encoder_attention_mask

        query = self.q_proj(hidden_states)
        key = self.k_proj(hidden_states)
        value = self.v_proj(hidden_states)

        query = self._split_heads(query, self.num_heads, self.head_dim)
        key = self._split_heads(key, self.num_heads, self.head_dim)
        value = self._split_heads(value, self.num_heads, self.head_dim)

        # Optional kv caching
        if layer_past is not None:
            past_key = layer_past[0]
            past_value = layer_past[1]
            key = torch.cat((past_key, key), dim=-2)
            value = torch.cat((past_value, value), dim=-2)

        present = None
        if use_cache is True:
            present = (key, value)

        # Avoid torch==2.1.2 specific bug for the memory-efficient backend in SDPA
        if self.require_contiguous_qkv and query.device.type == "cuda" and attention_mask is not None:
            query = query.contiguous()
            key = key.contiguous()
            value = value.contiguous()

        # We dispatch to SDPA's Flash Attention or Efficient kernels via this `is_causal` if statement instead of an inline conditional assignment
        # in SDPA to support both torch.compile's dynamic shapes and full graph options. An inline conditional prevents dynamic shapes from compiling.
        is_causal = True if attention_mask is None and q_len > 1 and not is_cross_attention else False

        attn_output = torch.nn.functional.scaled_dot_product_attention(
            query,
            key,
            value,
            attn_mask=attention_mask,
            dropout_p=self.attn_dropout.p if self.training else 0.0,
            is_causal=is_causal,
        )

        # Reshape outputs
        attn_output = attn_output.transpose(1, 2).contiguous()
        attn_output = attn_output.view(bsz, q_len, self.embed_dim)

        # Final projection
        attn_output = self.c_proj(attn_output)
        attn_output = self.resid_dropout(attn_output)

        return attn_output, present, None
482
+
483
class SVD_GPT2MLP(nn.Module):
    """GPT-2 feed-forward block with both dense projections replaced by
    low-rank `SVD_Linear` factors: fc -> activation -> proj -> dropout."""

    def __init__(self, intermediate_size=None, config=None, **kwargs):
        super().__init__()

        self.config = config

        hidden_size = config.hidden_size
        self.hidden_size = hidden_size

        # Default inner width: config.n_inner, falling back to the standard 4x
        # expansion when unset.
        if intermediate_size is None:
            intermediate_size = 4 * hidden_size if config.n_inner is None else config.n_inner
        self.intermediate_size = intermediate_size

        ratio = config.ratio
        self.ratio = ratio

        self.c_fc = SVD_Linear(hidden_size, intermediate_size, bias=True, ratio=ratio)
        self.c_proj = SVD_Linear(intermediate_size, hidden_size, bias=True, ratio=ratio)
        self.act = ACT2FN[config.activation_function]
        self.dropout = nn.Dropout(config.resid_pdrop)

    def forward(self, hidden_states: Optional[Tuple[torch.FloatTensor]]) -> torch.FloatTensor:
        """Expand, activate, contract, and apply residual dropout."""
        expanded = self.act(self.c_fc(hidden_states))
        return self.dropout(self.c_proj(expanded))
506
+
507
+
508
# Dispatch table: maps `config._attn_implementation` to the attention class
# instantiated by SVD_GPT2Block.
SVD_GPT2_ATTENTION_CLASSES = {"eager": SVD_GPT2Attention, "flash_attention_2": SVD_GPT2FlashAttention2, "sdpa": SVD_GPT2SdpaAttention}
509
+
510
class SVD_GPT2Block(nn.Module):
    """One pre-LayerNorm transformer block (attention + MLP, each with a
    residual connection), built from the SVD-factorized attention and MLP
    modules. Optionally adds a cross-attention sub-block when
    `config.add_cross_attention` is set.
    """

    def __init__(self, config, layer_idx=None):
        super().__init__()
        hidden_size = config.hidden_size
        inner_dim = config.n_inner if config.n_inner is not None else 4 * hidden_size
        # Pick eager / flash / sdpa attention per the config's implementation flag.
        attention_class = SVD_GPT2_ATTENTION_CLASSES[config._attn_implementation]

        self.ln_1 = nn.LayerNorm(hidden_size, eps=config.layer_norm_epsilon)
        self.attn = attention_class(config=config, layer_idx=layer_idx)
        self.ln_2 = nn.LayerNorm(hidden_size, eps=config.layer_norm_epsilon)

        if config.add_cross_attention:
            self.crossattention = attention_class(config=config, is_cross_attention=True, layer_idx=layer_idx)
            self.ln_cross_attn = nn.LayerNorm(hidden_size, eps=config.layer_norm_epsilon)

        self.mlp = SVD_GPT2MLP(inner_dim, config)

    def forward(
        self,
        hidden_states: Optional[Tuple[torch.FloatTensor]],
        layer_past: Optional[Tuple[torch.Tensor]] = None,
        attention_mask: Optional[torch.FloatTensor] = None,
        head_mask: Optional[torch.FloatTensor] = None,
        encoder_hidden_states: Optional[torch.Tensor] = None,
        encoder_attention_mask: Optional[torch.FloatTensor] = None,
        use_cache: Optional[bool] = False,
        output_attentions: Optional[bool] = False,
    ) -> Union[Tuple[torch.Tensor], Optional[Tuple[torch.Tensor, Tuple[torch.FloatTensor, ...]]]]:
        """Run LN -> attention -> residual, optional cross-attention, then
        LN -> MLP -> residual. Returns `(hidden_states, present, attentions...)`
        when `use_cache`, else `(hidden_states, attentions...)`.
        """
        residual = hidden_states
        hidden_states = self.ln_1(hidden_states)
        attn_outputs = self.attn(
            hidden_states,
            layer_past=layer_past,
            attention_mask=attention_mask,
            head_mask=head_mask,
            use_cache=use_cache,
            output_attentions=output_attentions,
        )
        attn_output = attn_outputs[0]  # output_attn: a, present, (attentions)
        # Keep (present, attentions...) for the return value assembly below.
        outputs = attn_outputs[1:]
        # residual connection
        hidden_states = attn_output + residual

        if encoder_hidden_states is not None:
            # add one self-attention block for cross-attention
            if not hasattr(self, "crossattention"):
                raise ValueError(
                    f"If `encoder_hidden_states` are passed, {self} has to be instantiated with "
                    "cross-attention layers by setting `config.add_cross_attention=True`"
                )
            residual = hidden_states
            hidden_states = self.ln_cross_attn(hidden_states)
            cross_attn_outputs = self.crossattention(
                hidden_states,
                attention_mask=attention_mask,
                head_mask=head_mask,
                encoder_hidden_states=encoder_hidden_states,
                encoder_attention_mask=encoder_attention_mask,
                output_attentions=output_attentions,
            )
            attn_output = cross_attn_outputs[0]
            # residual connection
            hidden_states = residual + attn_output
            outputs = outputs + cross_attn_outputs[2:]  # add cross attentions if we output attention weights

        residual = hidden_states
        hidden_states = self.ln_2(hidden_states)
        feed_forward_hidden_states = self.mlp(hidden_states)
        # residual connection
        hidden_states = residual + feed_forward_hidden_states

        if use_cache:
            outputs = (hidden_states,) + outputs
        else:
            # Drop the leading `present` entry (it is None when not caching).
            outputs = (hidden_states,) + outputs[1:]

        return outputs  # hidden_states, present, (attentions, cross_attentions)
587
+
588
+
589
class SVD_GPT2PreTrainedModel(PreTrainedModel):
    """
    An abstract class to handle weights initialization and a simple interface for downloading and loading pretrained
    models for SVD-GPT2.
    """

    config_class = SVDGPT2Config
    base_model_prefix = "transformer"
    is_parallelizable = True
    supports_gradient_checkpointing = True
    _no_split_modules = ["SVD_GPT2Block"]
    _skip_keys_device_placement = "past_key_values"
    _supports_flash_attn_2 = True
    _supports_sdpa = True

    # NOTE(review): the previous pass-through `__init__(self, *inputs, **kwargs)`
    # only forwarded to `super().__init__` and has been removed as redundant;
    # `PreTrainedModel.__init__` is inherited unchanged.

    def _init_weights(self, module):
        """Initialize the weights for SVD-GPT2 modules.

        Linear/Conv1D weights and embedding tables are drawn from
        N(0, config.initializer_range); linear biases and the embedding padding
        row are zeroed, and LayerNorm is reset to the identity transform
        (weight = 1, bias = 0).
        """
        if isinstance(module, (nn.Linear, Conv1D)):
            module.weight.data.normal_(mean=0.0, std=self.config.initializer_range)
            if module.bias is not None:
                module.bias.data.zero_()
        elif isinstance(module, nn.Embedding):
            module.weight.data.normal_(mean=0.0, std=self.config.initializer_range)
            if module.padding_idx is not None:
                module.weight.data[module.padding_idx].zero_()
        elif isinstance(module, nn.LayerNorm):
            module.bias.data.zero_()
            module.weight.data.fill_(1.0)
619
+
620
+
621
@dataclass
class SVD_GPT2DoubleHeadsModelOutput(ModelOutput):
    """
    Base class for outputs of models predicting if two sentences are consecutive or not.

    Args:
        loss (`torch.FloatTensor` of shape `(1,)`, *optional*, returned when `labels` is provided):
            Language modeling loss.
        mc_loss (`torch.FloatTensor` of shape `(1,)`, *optional*, returned when `mc_labels` is provided):
            Multiple choice classification loss.
        logits (`torch.FloatTensor` of shape `(batch_size, num_choices, sequence_length, config.vocab_size)`):
            Prediction scores of the language modeling head (scores for each vocabulary token before SoftMax).
        mc_logits (`torch.FloatTensor` of shape `(batch_size, num_choices)`):
            Prediction scores of the multiple choice classification head (scores for each choice before SoftMax).
        past_key_values (`Tuple[Tuple[torch.Tensor]]`, *optional*, returned when `use_cache=True` is passed or when `config.use_cache=True`):
            Tuple of length `config.n_layers`, containing tuples of tensors of shape `(batch_size, num_heads,
            sequence_length, embed_size_per_head)`).

            Contains pre-computed hidden-states (key and values in the attention blocks) that can be used (see
            `past_key_values` input) to speed up sequential decoding.
        hidden_states (`tuple(torch.FloatTensor)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`):
            Tuple of `torch.FloatTensor` (one for the output of the embeddings + one for the output of each layer) of
            shape `(batch_size, sequence_length, hidden_size)`.

            Hidden-states of the model at the output of each layer plus the initial embedding outputs.
        attentions (`tuple(torch.FloatTensor)`, *optional*, returned when `output_attentions=True` is passed or when `config.output_attentions=True`):
            Tuple of `torch.FloatTensor` (one for each layer) of shape `(batch_size, num_heads, sequence_length,
            sequence_length)`.

            GPT2Attentions weights after the attention softmax, used to compute the weighted average in the
            self-attention heads.
    """

    # All fields are Optional: each is populated only under the conditions
    # described in the docstring above.
    loss: Optional[torch.FloatTensor] = None
    mc_loss: Optional[torch.FloatTensor] = None
    logits: Optional[torch.FloatTensor] = None
    mc_logits: Optional[torch.FloatTensor] = None
    past_key_values: Optional[Tuple[Tuple[torch.FloatTensor]]] = None
    hidden_states: Optional[Tuple[torch.FloatTensor]] = None
    attentions: Optional[Tuple[torch.FloatTensor]] = None
661
+
662
+
663
# Shared docstring fragments injected into the model classes below via the
# `add_start_docstrings` / `add_start_docstrings_to_model_forward` decorators.
GPT2_START_DOCSTRING = r"""

    This model inherits from [`PreTrainedModel`]. Check the superclass documentation for the generic methods the
    library implements for all its model (such as downloading or saving, resizing the input embeddings, pruning heads
    etc.)

    This model is also a PyTorch [torch.nn.Module](https://pytorch.org/docs/stable/nn.html#torch.nn.Module) subclass.
    Use it as a regular PyTorch Module and refer to the PyTorch documentation for all matter related to general usage
    and behavior.

    Parameters:
        config ([`GPT2Config`]): Model configuration class with all the parameters of the model.
            Initializing with a config file does not load the weights associated with the model, only the
            configuration. Check out the [`~PreTrainedModel.from_pretrained`] method to load the model weights.
"""

GPT2_INPUTS_DOCSTRING = r"""
    Args:
        input_ids (`torch.LongTensor` of shape `(batch_size, input_ids_length)`):
            `input_ids_length` = `sequence_length` if `past_key_values` is `None` else
            `past_key_values[0][0].shape[-2]` (`sequence_length` of input past key value states). Indices of input
            sequence tokens in the vocabulary.

            If `past_key_values` is used, only `input_ids` that do not have their past calculated should be passed as
            `input_ids`.

            Indices can be obtained using [`AutoTokenizer`]. See [`PreTrainedTokenizer.encode`] and
            [`PreTrainedTokenizer.__call__`] for details.

            [What are input IDs?](../glossary#input-ids)
        past_key_values (`Tuple[Tuple[torch.Tensor]]` of length `config.n_layers`):
            Contains precomputed hidden-states (key and values in the attention blocks) as computed by the model (see
            `past_key_values` output below). Can be used to speed up sequential decoding. The `input_ids` which have
            their past given to this model should not be passed as `input_ids` as they have already been computed.
        attention_mask (`torch.FloatTensor` of shape `(batch_size, sequence_length)`, *optional*):
            Mask to avoid performing attention on padding token indices. Mask values selected in `[0, 1]`:

            - 1 for tokens that are **not masked**,
            - 0 for tokens that are **masked**.

            If `past_key_values` is used, `attention_mask` needs to contain the masking strategy that was used for
            `past_key_values`. In other words, the `attention_mask` always has to have the length:
            `len(past_key_values) + len(input_ids)`

            [What are attention masks?](../glossary#attention-mask)
        token_type_ids (`torch.LongTensor` of shape `(batch_size, input_ids_length)`, *optional*):
            Segment token indices to indicate first and second portions of the inputs. Indices are selected in `[0,
            1]`:

            - 0 corresponds to a *sentence A* token,
            - 1 corresponds to a *sentence B* token.

            [What are token type IDs?](../glossary#token-type-ids)
        position_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
            Indices of positions of each input sequence tokens in the position embeddings. Selected in the range `[0,
            config.max_position_embeddings - 1]`.

            [What are position IDs?](../glossary#position-ids)
        head_mask (`torch.FloatTensor` of shape `(num_heads,)` or `(num_layers, num_heads)`, *optional*):
            Mask to nullify selected heads of the self-attention modules. Mask values selected in `[0, 1]`:

            - 1 indicates the head is **not masked**,
            - 0 indicates the head is **masked**.

        inputs_embeds (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`, *optional*):
            Optionally, instead of passing `input_ids` you can choose to directly pass an embedded representation. This
            is useful if you want more control over how to convert `input_ids` indices into associated vectors than the
            model's internal embedding lookup matrix.

            If `past_key_values` is used, optionally only the last `inputs_embeds` have to be input (see
            `past_key_values`).
        use_cache (`bool`, *optional*):
            If set to `True`, `past_key_values` key value states are returned and can be used to speed up decoding (see
            `past_key_values`).
        output_attentions (`bool`, *optional*):
            Whether or not to return the attentions tensors of all attention layers. See `attentions` under returned
            tensors for more detail.
        output_hidden_states (`bool`, *optional*):
            Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors for
            more detail.
        return_dict (`bool`, *optional*):
            Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple.
"""

PARALLELIZE_DOCSTRING = r"""
    This is an experimental feature and is a subject to change at a moment's notice.

    Uses a device map to distribute attention modules of the model across several devices. If no device map is given,
    it will evenly distribute blocks across all devices.

    Args:
        device_map (`Dict[int, list]`, *optional*):
            A dictionary that maps attention modules to devices. Note that the embedding module and LMHead are always
            automatically mapped to the first device (for esoteric reasons). That means that the first device should
            have fewer attention modules mapped to it than other devices. For reference, the gpt2 models have the
            following number of attention modules:

                - openai-community/gpt2: 12
                - openai-community/gpt2-medium: 24
                - openai-community/gpt2-large: 36
                - openai-community/gpt2-xl: 48

    Example:

    ```python
    # Here is an example of a device map on a machine with 4 GPUs using gpt2-xl, which has a total of 48 attention modules:
    model = GPT2LMHeadModel.from_pretrained("openai-community/gpt2-xl")
    device_map = {
        0: [0, 1, 2, 3, 4, 5, 6, 7, 8],
        1: [9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21],
        2: [22, 23, 24, 25, 26, 27, 28, 29, 30, 31, 32, 33, 34],
        3: [35, 36, 37, 38, 39, 40, 41, 42, 43, 44, 45, 46, 47],
    }
    model.parallelize(device_map)
    ```
"""

DEPARALLELIZE_DOCSTRING = r"""
    Moves the model to cpu from a model parallel state.

    Example:

    ```python
    # On a 4 GPU machine with openai-community/gpt2-large:
    model = GPT2LMHeadModel.from_pretrained("openai-community/gpt2-large")
    device_map = {
        0: [0, 1, 2, 3, 4, 5, 6, 7],
        1: [8, 9, 10, 11, 12, 13, 14, 15],
        2: [16, 17, 18, 19, 20, 21, 22, 23],
        3: [24, 25, 26, 27, 28, 29, 30, 31, 32, 33, 34, 35],
    }
    model.parallelize(device_map)  # Splits the model across several devices
    model.deparallelize()  # Put the model back on cpu and cleans memory by calling torch.cuda.empty_cache()
    ```
"""
796
+
797
+
798
@add_start_docstrings(
    "The bare GPT2 Model transformer outputting raw hidden-states without any specific head on top.",
    GPT2_START_DOCSTRING,
)
class SVD_GPT2Model(SVD_GPT2PreTrainedModel):
    # Disable direct buffer/parameter assignment during state-dict loading.
    _supports_param_buffer_assignment = False

    def __init__(self, config):
        super().__init__(config)

        self.embed_dim = config.hidden_size

        # Token embeddings and learned absolute position embeddings.
        self.wte = nn.Embedding(config.vocab_size, self.embed_dim)
        self.wpe = nn.Embedding(config.max_position_embeddings, self.embed_dim)

        self.drop = nn.Dropout(config.embd_pdrop)
        self.h = nn.ModuleList([SVD_GPT2Block(config, layer_idx=i) for i in range(config.num_hidden_layers)])
        self.ln_f = nn.LayerNorm(self.embed_dim, eps=config.layer_norm_epsilon)

        # Model parallel
        self.model_parallel = False
        self.device_map = None
        self.gradient_checkpointing = False
        self._attn_implementation = config._attn_implementation

        # Initialize weights and apply final processing
        self.post_init()

    @add_start_docstrings(PARALLELIZE_DOCSTRING)
    def parallelize(self, device_map=None):
        # Check validity of device_map
        warnings.warn(
            "`GPT2Model.parallelize` is deprecated and will be removed in v5 of Transformers, you should load your"
            " model with `device_map='balanced'` in the call to `from_pretrained`. You can also provide your own"
            " `device_map` but it needs to be a dictionary module_name to device, so for instance {'h.0': 0, 'h.1': 1,"
            " ...}",
            FutureWarning,
        )
        self.device_map = (
            get_device_map(len(self.h), range(torch.cuda.device_count())) if device_map is None else device_map
        )
        assert_device_map(self.device_map, len(self.h))
        self.model_parallel = True
        # Embeddings live on the first device; ln_f on the last.
        self.first_device = "cpu" if "cpu" in self.device_map.keys() else "cuda:" + str(min(self.device_map.keys()))
        self.last_device = "cuda:" + str(max(self.device_map.keys()))
        self.wte = self.wte.to(self.first_device)
        self.wpe = self.wpe.to(self.first_device)
        # Load onto devices
        for k, v in self.device_map.items():
            for block in v:
                cuda_device = "cuda:" + str(k)
                self.h[block] = self.h[block].to(cuda_device)
        # ln_f to last
        self.ln_f = self.ln_f.to(self.last_device)

    @add_start_docstrings(DEPARALLELIZE_DOCSTRING)
    def deparallelize(self):
        # Move every submodule back to CPU and release cached GPU memory.
        warnings.warn(
            "Like `parallelize`, `deparallelize` is deprecated and will be removed in v5 of Transformers.",
            FutureWarning,
        )
        self.model_parallel = False
        self.device_map = None
        self.first_device = "cpu"
        self.last_device = "cpu"
        self.wte = self.wte.to("cpu")
        self.wpe = self.wpe.to("cpu")
        for index in range(len(self.h)):
            self.h[index] = self.h[index].to("cpu")
        self.ln_f = self.ln_f.to("cpu")
        torch.cuda.empty_cache()

    def get_input_embeddings(self):
        # Token embedding table (also used for token_type embeddings in forward).
        return self.wte

    def set_input_embeddings(self, new_embeddings):
        self.wte = new_embeddings

    def _prune_heads(self, heads_to_prune):
        """
        Prunes heads of the model. heads_to_prune: dict of {layer_num: list of heads to prune in this layer}
        """
        for layer, heads in heads_to_prune.items():
            self.h[layer].attn.prune_heads(heads)

    @add_start_docstrings_to_model_forward(GPT2_INPUTS_DOCSTRING)
    @add_code_sample_docstrings(
        checkpoint=_CHECKPOINT_FOR_DOC,
        output_type=BaseModelOutputWithPastAndCrossAttentions,
        config_class=_CONFIG_FOR_DOC,
    )
    def forward(
        self,
        input_ids: Optional[torch.LongTensor] = None,
        past_key_values: Optional[Tuple[Tuple[torch.Tensor]]] = None,
        attention_mask: Optional[torch.FloatTensor] = None,
        token_type_ids: Optional[torch.LongTensor] = None,
        position_ids: Optional[torch.LongTensor] = None,
        head_mask: Optional[torch.FloatTensor] = None,
        inputs_embeds: Optional[torch.FloatTensor] = None,
        encoder_hidden_states: Optional[torch.Tensor] = None,
        encoder_attention_mask: Optional[torch.FloatTensor] = None,
        use_cache: Optional[bool] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
    ) -> Union[Tuple, BaseModelOutputWithPastAndCrossAttentions]:
        """Embed inputs, run every transformer block, and apply the final LayerNorm.

        Uses legacy tuple-style `past_key_values` (one `(key, value)` tuple per
        layer); returns either a plain tuple or
        `BaseModelOutputWithPastAndCrossAttentions` depending on `return_dict`.
        """
        # Fall back to config-level defaults for any flag not passed explicitly.
        output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
        output_hidden_states = (
            output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
        )
        use_cache = use_cache if use_cache is not None else self.config.use_cache
        return_dict = return_dict if return_dict is not None else self.config.use_return_dict

        if input_ids is not None and inputs_embeds is not None:
            raise ValueError("You cannot specify both input_ids and inputs_embeds at the same time")
        elif input_ids is not None:
            self.warn_if_padding_and_no_attention_mask(input_ids, attention_mask)
            input_shape = input_ids.size()
            input_ids = input_ids.view(-1, input_shape[-1])
            batch_size = input_ids.shape[0]
        elif inputs_embeds is not None:
            input_shape = inputs_embeds.size()[:-1]
            batch_size = inputs_embeds.shape[0]
        else:
            raise ValueError("You have to specify either input_ids or inputs_embeds")

        device = input_ids.device if input_ids is not None else inputs_embeds.device

        if token_type_ids is not None:
            token_type_ids = token_type_ids.view(-1, input_shape[-1])

        if past_key_values is None:
            past_length = 0
            past_key_values = tuple([None] * len(self.h))
        else:
            # Number of tokens already cached: key tensor's sequence dimension.
            past_length = past_key_values[0][0].size(-2)
        if position_ids is None:
            # Positions continue from the cached prefix.
            position_ids = torch.arange(past_length, input_shape[-1] + past_length, dtype=torch.long, device=device)
            position_ids = position_ids.unsqueeze(0)

        if inputs_embeds is None:
            inputs_embeds = self.wte(input_ids)
        position_embeds = self.wpe(position_ids)
        hidden_states = inputs_embeds + position_embeds

        # Attention mask.
        # SDPA needs a 4D causal mask and cannot return attention weights or
        # apply a head mask, hence the extra conditions.
        _use_sdpa = self._attn_implementation == "sdpa" and output_attentions is False and head_mask is None
        attention_mask = attention_mask.view(batch_size, -1) if attention_mask is not None else None
        if self._attn_implementation == "flash_attention_2":
            # Flash attention takes the raw 2D mask, and only when padding exists.
            attention_mask = attention_mask if (attention_mask is not None and 0 in attention_mask) else None
        elif _use_sdpa:
            attention_mask = _prepare_4d_causal_attention_mask_for_sdpa(
                attention_mask=attention_mask,
                input_shape=(batch_size, input_shape[-1]),
                inputs_embeds=inputs_embeds,
                past_key_values_length=past_length,
            )
        else:
            if attention_mask is not None:
                # We create a 3D attention mask from a 2D tensor mask.
                # Sizes are [batch_size, 1, 1, to_seq_length]
                # So we can broadcast to [batch_size, num_heads, from_seq_length, to_seq_length]
                # this attention mask is more simple than the triangular masking of causal attention
                # used in OpenAI GPT, we just need to prepare the broadcast dimension here.
                attention_mask = attention_mask[:, None, None, :]

                # Since attention_mask is 1.0 for positions we want to attend and 0.0 for
                # masked positions, this operation will create a tensor which is 0.0 for
                # positions we want to attend and the dtype's smallest value for masked positions.
                # Since we are adding it to the raw scores before the softmax, this is
                # effectively the same as removing these entirely.
                attention_mask = attention_mask.to(dtype=self.dtype)  # fp16 compatibility
                attention_mask = (1.0 - attention_mask) * torch.finfo(self.dtype).min

        # If a 2D or 3D attention mask is provided for the cross-attention
        # we need to make broadcastable to [batch_size, num_heads, seq_length, seq_length]
        if self.config.add_cross_attention and encoder_hidden_states is not None:
            encoder_batch_size, encoder_sequence_length, _ = encoder_hidden_states.size()
            encoder_hidden_shape = (encoder_batch_size, encoder_sequence_length)
            if encoder_attention_mask is None:
                encoder_attention_mask = torch.ones(encoder_hidden_shape, device=device)
            if _use_sdpa:
                encoder_attention_mask = _prepare_4d_attention_mask_for_sdpa(
                    mask=encoder_attention_mask, dtype=inputs_embeds.dtype, tgt_len=input_shape[-1]
                )
            elif not self._attn_implementation == "flash_attention_2":
                encoder_attention_mask = self.invert_attention_mask(encoder_attention_mask)
        else:
            encoder_attention_mask = None

        # Prepare head mask if needed
        # 1.0 in head_mask indicate we keep the head
        # attention_probs has shape bsz x n_heads x N x N
        # head_mask has shape n_layer x batch x n_heads x N x N
        head_mask = self.get_head_mask(head_mask, self.config.n_layer)

        if token_type_ids is not None:
            # token_type embeddings are looked up in the same table as tokens.
            token_type_embeds = self.wte(token_type_ids)
            hidden_states = hidden_states + token_type_embeds

        hidden_states = self.drop(hidden_states)

        output_shape = (-1,) + input_shape[1:] + (hidden_states.size(-1),)

        if self.gradient_checkpointing and self.training:
            if use_cache:
                logger.warning_once(
                    "`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`..."
                )
                use_cache = False

        presents = () if use_cache else None
        all_self_attentions = () if output_attentions else None
        all_cross_attentions = () if output_attentions and self.config.add_cross_attention else None
        all_hidden_states = () if output_hidden_states else None
        for i, (block, layer_past) in enumerate(zip(self.h, past_key_values)):
            # Model parallel
            if self.model_parallel:
                torch.cuda.set_device(hidden_states.device)
                # Ensure layer_past is on same device as hidden_states (might not be correct)
                if layer_past is not None:
                    layer_past = tuple(past_state.to(hidden_states.device) for past_state in layer_past)
                # Ensure that attention_mask is always on the same device as hidden_states
                if attention_mask is not None:
                    attention_mask = attention_mask.to(hidden_states.device)
                if isinstance(head_mask, torch.Tensor):
                    head_mask = head_mask.to(hidden_states.device)
            if output_hidden_states:
                all_hidden_states = all_hidden_states + (hidden_states,)

            if self.gradient_checkpointing and self.training:
                # Positional args must match SVD_GPT2Block.forward's order;
                # layer_past is None here (caching is disabled under checkpointing).
                outputs = self._gradient_checkpointing_func(
                    block.__call__,
                    hidden_states,
                    None,
                    attention_mask,
                    head_mask[i],
                    encoder_hidden_states,
                    encoder_attention_mask,
                    use_cache,
                    output_attentions,
                )
            else:
                outputs = block(
                    hidden_states,
                    layer_past=layer_past,
                    attention_mask=attention_mask,
                    head_mask=head_mask[i],
                    encoder_hidden_states=encoder_hidden_states,
                    encoder_attention_mask=encoder_attention_mask,
                    use_cache=use_cache,
                    output_attentions=output_attentions,
                )

            hidden_states = outputs[0]
            if use_cache is True:
                presents = presents + (outputs[1],)

            if output_attentions:
                # Tuple layout from the block depends on whether `present` was emitted.
                all_self_attentions = all_self_attentions + (outputs[2 if use_cache else 1],)
                if self.config.add_cross_attention:
                    all_cross_attentions = all_cross_attentions + (outputs[3 if use_cache else 2],)

            # Model Parallel: If it's the last layer for that device, put things on the next device
            if self.model_parallel:
                for k, v in self.device_map.items():
                    if i == v[-1] and "cuda:" + str(k) != self.last_device:
                        hidden_states = hidden_states.to("cuda:" + str(k + 1))

        hidden_states = self.ln_f(hidden_states)

        hidden_states = hidden_states.view(output_shape)
        # Add last hidden state
        if output_hidden_states:
            all_hidden_states = all_hidden_states + (hidden_states,)

        if not return_dict:
            return tuple(
                v
                for v in [hidden_states, presents, all_hidden_states, all_self_attentions, all_cross_attentions]
                if v is not None
            )

        return BaseModelOutputWithPastAndCrossAttentions(
            last_hidden_state=hidden_states,
            past_key_values=presents,
            hidden_states=all_hidden_states,
            attentions=all_self_attentions,
            cross_attentions=all_cross_attentions,
        )
1089
+
1090
+
1091
+
1092
@add_start_docstrings(
    """
    The GPT2 Model transformer with a language modeling head on top (linear layer with weights tied to the input
    embeddings).
    """,
    GPT2_START_DOCSTRING,
)
class SVD_GPT2LMHeadModel(SVD_GPT2PreTrainedModel, GenerationMixin):
    # lm_head.weight is tied to the input embedding table by post_init().
    _tied_weights_keys = ["lm_head.weight"]

    def __init__(self, config):
        super().__init__(config)
        self.transformer = SVD_GPT2Model(config)
        self.lm_head = nn.Linear(config.n_embd, config.vocab_size, bias=False)

        # Model parallel
        self.model_parallel = False
        self.device_map = None

        # Initialize weights and apply final processing
        self.post_init()

    @add_start_docstrings(PARALLELIZE_DOCSTRING)
    def parallelize(self, device_map=None):
        warnings.warn(
            "`GPT2LMHeadModel.parallelize` is deprecated and will be removed in v5 of Transformers, you should load"
            " your model with `device_map='balanced'` in the call to `from_pretrained`. You can also provide your own"
            " `device_map` but it needs to be a dictionary module_name to device, so for instance {'transformer.h.0':"
            " 0, 'transformer.h.1': 1, ...}",
            FutureWarning,
        )
        self.device_map = (
            get_device_map(len(self.transformer.h), range(torch.cuda.device_count()))
            if device_map is None
            else device_map
        )
        assert_device_map(self.device_map, len(self.transformer.h))
        self.transformer.parallelize(self.device_map)
        # lm_head shares the first device with the embeddings (weight tying).
        self.lm_head = self.lm_head.to(self.transformer.first_device)
        self.model_parallel = True

    @add_start_docstrings(DEPARALLELIZE_DOCSTRING)
    def deparallelize(self):
        warnings.warn(
            "Like `parallelize`, `deparallelize` is deprecated and will be removed in v5 of Transformers.",
            FutureWarning,
        )
        self.transformer.deparallelize()
        self.transformer = self.transformer.to("cpu")
        self.lm_head = self.lm_head.to("cpu")
        self.model_parallel = False
        torch.cuda.empty_cache()

    def get_output_embeddings(self):
        return self.lm_head

    def set_output_embeddings(self, new_embeddings):
        self.lm_head = new_embeddings

    @add_start_docstrings_to_model_forward(GPT2_INPUTS_DOCSTRING)
    @add_code_sample_docstrings(
        checkpoint=_CHECKPOINT_FOR_DOC,
        output_type=CausalLMOutputWithCrossAttentions,
        config_class=_CONFIG_FOR_DOC,
    )
    def forward(
        self,
        input_ids: Optional[torch.LongTensor] = None,
        past_key_values: Optional[Tuple[Tuple[torch.Tensor]]] = None,
        attention_mask: Optional[torch.FloatTensor] = None,
        token_type_ids: Optional[torch.LongTensor] = None,
        position_ids: Optional[torch.LongTensor] = None,
        head_mask: Optional[torch.FloatTensor] = None,
        inputs_embeds: Optional[torch.FloatTensor] = None,
        encoder_hidden_states: Optional[torch.Tensor] = None,
        encoder_attention_mask: Optional[torch.FloatTensor] = None,
        labels: Optional[torch.LongTensor] = None,
        use_cache: Optional[bool] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
    ) -> Union[Tuple, CausalLMOutputWithCrossAttentions]:
        r"""
        labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
            Labels for language modeling. Note that the labels **are shifted** inside the model, i.e. you can set
            `labels = input_ids` Indices are selected in `[-100, 0, ..., config.vocab_size]` All labels set to `-100`
            are ignored (masked), the loss is only computed for labels in `[0, ..., config.vocab_size]`
        """
        return_dict = return_dict if return_dict is not None else self.config.use_return_dict

        transformer_outputs = self.transformer(
            input_ids,
            past_key_values=past_key_values,
            attention_mask=attention_mask,
            token_type_ids=token_type_ids,
            position_ids=position_ids,
            head_mask=head_mask,
            inputs_embeds=inputs_embeds,
            encoder_hidden_states=encoder_hidden_states,
            encoder_attention_mask=encoder_attention_mask,
            use_cache=use_cache,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
        )
        hidden_states = transformer_outputs[0]

        # Set device for model parallelism
        if self.model_parallel:
            torch.cuda.set_device(self.transformer.first_device)
            hidden_states = hidden_states.to(self.lm_head.weight.device)

        lm_logits = self.lm_head(hidden_states)

        loss = None
        if labels is not None:
            # move labels to correct device to enable model parallelism
            labels = labels.to(lm_logits.device)
            # Shift so that tokens < n predict n
            shift_logits = lm_logits[..., :-1, :].contiguous()
            shift_labels = labels[..., 1:].contiguous()
            # Flatten the tokens
            loss_fct = CrossEntropyLoss()
            loss = loss_fct(shift_logits.view(-1, shift_logits.size(-1)), shift_labels.view(-1))

        if not return_dict:
            output = (lm_logits,) + transformer_outputs[1:]
            return ((loss,) + output) if loss is not None else output

        return CausalLMOutputWithCrossAttentions(
            loss=loss,
            logits=lm_logits,
            past_key_values=transformer_outputs.past_key_values,
            hidden_states=transformer_outputs.hidden_states,
            attentions=transformer_outputs.attentions,
            cross_attentions=transformer_outputs.cross_attentions,
        )

    @staticmethod
    def _reorder_cache(
        past_key_values: Tuple[Tuple[torch.Tensor]], beam_idx: torch.Tensor
    ) -> Tuple[Tuple[torch.Tensor]]:
        """
        This function is used to re-order the `past_key_values` cache if [`~PreTrainedModel.beam_search`] or
        [`~PreTrainedModel.beam_sample`] is called. This is required to match `past_key_values` with the correct
        beam_idx at every generation step.
        """
        return tuple(
            tuple(past_state.index_select(0, beam_idx.to(past_state.device)) for past_state in layer_past)
            for layer_past in past_key_values
        )
1243
+
1244
+
1245
+ @add_start_docstrings(
1246
+ """
1247
+ The GPT2 Model transformer with a language modeling and a multiple-choice classification head on top e.g. for
1248
+ RocStories/SWAG tasks. The two heads are two linear layers. The language modeling head has its weights tied to the
1249
+ input embeddings, the classification head takes as input the input of a specified classification token index in the
1250
+ input sequence).
1251
+ """,
1252
+ GPT2_START_DOCSTRING,
1253
+ )
1254
+ class SVD_GPT2DoubleHeadsModel(SVD_GPT2PreTrainedModel, GenerationMixin):
1255
+ _tied_weights_keys = ["lm_head.weight"]
1256
+
1257
    def __init__(self, config):
        super().__init__(config)
        # The multiple-choice head emits one score per choice, so num_labels
        # is forced to 1 before building SequenceSummary.
        config.num_labels = 1
        self.transformer = SVD_GPT2Model(config)
        self.lm_head = nn.Linear(config.n_embd, config.vocab_size, bias=False)
        self.multiple_choice_head = SequenceSummary(config)

        # Model parallel
        self.model_parallel = False
        self.device_map = None

        # Initialize weights and apply final processing
        self.post_init()
1270
+
1271
    @add_start_docstrings(PARALLELIZE_DOCSTRING)
    def parallelize(self, device_map=None):
        # Deprecated: distributes transformer blocks per device_map; both heads
        # stay on the transformer's first device.
        warnings.warn(
            "`GPT2DoubleHeadsModel.parallelize` is deprecated and will be removed in v5 of Transformers, you should"
            " load your model with `device_map='balanced'` in the call to `from_pretrained`. You can also provide your"
            " own `device_map` but it needs to be a dictionary module_name to device, so for instance"
            " {'transformer.h.0': 0, 'transformer.h.1': 1, ...}",
            FutureWarning,
        )
        self.device_map = (
            get_device_map(len(self.transformer.h), range(torch.cuda.device_count()))
            if device_map is None
            else device_map
        )
        assert_device_map(self.device_map, len(self.transformer.h))
        self.transformer.parallelize(self.device_map)
        self.lm_head = self.lm_head.to(self.transformer.first_device)
        self.multiple_choice_head = self.multiple_choice_head.to(self.transformer.first_device)
        self.model_parallel = True
1290
+
1291
+ @add_start_docstrings(DEPARALLELIZE_DOCSTRING)
1292
+ def deparallelize(self):
1293
+ warnings.warn(
1294
+ "Like `parallelize`, `deparallelize` is deprecated and will be removed in v5 of Transformers.",
1295
+ FutureWarning,
1296
+ )
1297
+ self.transformer.deparallelize()
1298
+ self.transformer = self.transformer.to("cpu")
1299
+ self.lm_head = self.lm_head.to("cpu")
1300
+ self.multiple_choice_head = self.multiple_choice_head.to("cpu")
1301
+ self.model_parallel = False
1302
+ torch.cuda.empty_cache()
1303
+
1304
+ def get_output_embeddings(self):
1305
+ return self.lm_head
1306
+
1307
+ def set_output_embeddings(self, new_embeddings):
1308
+ self.lm_head = new_embeddings
1309
+
1310
+ @add_start_docstrings_to_model_forward(GPT2_INPUTS_DOCSTRING)
1311
+ @replace_return_docstrings(output_type=SVD_GPT2DoubleHeadsModelOutput, config_class=_CONFIG_FOR_DOC)
1312
+ def forward(
1313
+ self,
1314
+ input_ids: Optional[torch.LongTensor] = None,
1315
+ past_key_values: Optional[Tuple[Tuple[torch.Tensor]]] = None,
1316
+ attention_mask: Optional[torch.FloatTensor] = None,
1317
+ token_type_ids: Optional[torch.LongTensor] = None,
1318
+ position_ids: Optional[torch.LongTensor] = None,
1319
+ head_mask: Optional[torch.FloatTensor] = None,
1320
+ inputs_embeds: Optional[torch.FloatTensor] = None,
1321
+ mc_token_ids: Optional[torch.LongTensor] = None,
1322
+ labels: Optional[torch.LongTensor] = None,
1323
+ mc_labels: Optional[torch.LongTensor] = None,
1324
+ use_cache: Optional[bool] = None,
1325
+ output_attentions: Optional[bool] = None,
1326
+ output_hidden_states: Optional[bool] = None,
1327
+ return_dict: Optional[bool] = None,
1328
+ **kwargs,
1329
+ ) -> Union[Tuple, SVD_GPT2DoubleHeadsModelOutput]:
1330
+ r"""
1331
+ mc_token_ids (`torch.LongTensor` of shape `(batch_size, num_choices)`, *optional*, default to index of the last token of the input):
1332
+ Index of the classification token in each input sequence. Selected in the range `[0, input_ids.size(-1) -
1333
+ 1]`.
1334
+ labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
1335
+ Labels for language modeling. Note that the labels **are shifted** inside the model, i.e. you can set
1336
+ `labels = input_ids`. Indices are selected in `[-100, 0, ..., config.vocab_size - 1]`. All labels set to
1337
+ `-100` are ignored (masked), the loss is only computed for labels in `[0, ..., config.vocab_size - 1]`
1338
+ mc_labels (`torch.LongTensor` of shape `(batch_size)`, *optional*):
1339
+ Labels for computing the multiple choice classification loss. Indices should be in `[0, ..., num_choices]`
1340
+ where *num_choices* is the size of the second dimension of the input tensors. (see *input_ids* above)
1341
+
1342
+ Return:
1343
+
1344
+ Example:
1345
+
1346
+ ```python
1347
+ >>> import torch
1348
+ >>> from transformers import AutoTokenizer, GPT2DoubleHeadsModel
1349
+
1350
+ >>> tokenizer = AutoTokenizer.from_pretrained("openai-community/gpt2")
1351
+ >>> model = GPT2DoubleHeadsModel.from_pretrained("openai-community/gpt2")
1352
+
1353
+ >>> # Add a [CLS] to the vocabulary (we should train it also!)
1354
+ >>> num_added_tokens = tokenizer.add_special_tokens({"cls_token": "[CLS]"})
1355
+ >>> # Update the model embeddings with the new vocabulary size
1356
+ >>> embedding_layer = model.resize_token_embeddings(len(tokenizer))
1357
+
1358
+ >>> choices = ["Hello, my dog is cute [CLS]", "Hello, my cat is cute [CLS]"]
1359
+ >>> encoded_choices = [tokenizer.encode(s) for s in choices]
1360
+ >>> cls_token_location = [tokens.index(tokenizer.cls_token_id) for tokens in encoded_choices]
1361
+
1362
+ >>> input_ids = torch.tensor(encoded_choices).unsqueeze(0) # Batch size: 1, number of choices: 2
1363
+ >>> mc_token_ids = torch.tensor([cls_token_location]) # Batch size: 1
1364
+
1365
+ >>> outputs = model(input_ids, mc_token_ids=mc_token_ids)
1366
+ >>> lm_logits = outputs.logits
1367
+ >>> mc_logits = outputs.mc_logits
1368
+ ```"""
1369
+ return_dict = return_dict if return_dict is not None else self.config.use_return_dict
1370
+
1371
+ transformer_outputs = self.transformer(
1372
+ input_ids,
1373
+ past_key_values=past_key_values,
1374
+ attention_mask=attention_mask,
1375
+ token_type_ids=token_type_ids,
1376
+ position_ids=position_ids,
1377
+ head_mask=head_mask,
1378
+ inputs_embeds=inputs_embeds,
1379
+ use_cache=use_cache,
1380
+ output_attentions=output_attentions,
1381
+ output_hidden_states=output_hidden_states,
1382
+ return_dict=return_dict,
1383
+ )
1384
+
1385
+ hidden_states = transformer_outputs[0]
1386
+
1387
+ # Set device for model parallelism
1388
+ if self.model_parallel:
1389
+ torch.cuda.set_device(self.transformer.first_device)
1390
+ hidden_states = hidden_states.to(self.lm_head.weight.device)
1391
+
1392
+ lm_logits = self.lm_head(hidden_states)
1393
+ mc_logits = self.multiple_choice_head(hidden_states, mc_token_ids).squeeze(-1)
1394
+
1395
+ mc_loss = None
1396
+ if mc_labels is not None:
1397
+ loss_fct = CrossEntropyLoss()
1398
+ mc_loss = loss_fct(mc_logits.view(-1, mc_logits.size(-1)), mc_labels.view(-1))
1399
+ lm_loss = None
1400
+ if labels is not None:
1401
+ labels = labels.to(lm_logits.device)
1402
+ shift_logits = lm_logits[..., :-1, :].contiguous()
1403
+ shift_labels = labels[..., 1:].contiguous()
1404
+ loss_fct = CrossEntropyLoss()
1405
+ lm_loss = loss_fct(shift_logits.view(-1, shift_logits.size(-1)), shift_labels.view(-1))
1406
+
1407
+ if not return_dict:
1408
+ output = (lm_logits, mc_logits) + transformer_outputs[1:]
1409
+ if mc_loss is not None:
1410
+ output = (mc_loss,) + output
1411
+ return ((lm_loss,) + output) if lm_loss is not None else output
1412
+
1413
+ return SVD_GPT2DoubleHeadsModelOutput(
1414
+ loss=lm_loss,
1415
+ mc_loss=mc_loss,
1416
+ logits=lm_logits,
1417
+ mc_logits=mc_logits,
1418
+ past_key_values=transformer_outputs.past_key_values,
1419
+ hidden_states=transformer_outputs.hidden_states,
1420
+ attentions=transformer_outputs.attentions,
1421
+ )
1422
+
1423
+ @staticmethod
1424
+ def _reorder_cache(
1425
+ past_key_values: Tuple[Tuple[torch.Tensor]], beam_idx: torch.Tensor
1426
+ ) -> Tuple[Tuple[torch.Tensor]]:
1427
+ """
1428
+ This function is used to re-order the `past_key_values` cache if [`~PreTrainedModel.beam_search`] or
1429
+ [`~PreTrainedModel.beam_sample`] is called. This is required to match `past_key_values` with the correct
1430
+ beam_idx at every generation step.
1431
+ """
1432
+ return tuple(
1433
+ tuple(past_state.index_select(0, beam_idx.to(past_state.device)) for past_state in layer_past)
1434
+ for layer_past in past_key_values
1435
+ )
1436
+
1437
+
1438
+ @add_start_docstrings(
1439
+ """
1440
+ The GPT2 Model transformer with a sequence classification head on top (linear layer).
1441
+
1442
+ [`GPT2ForSequenceClassification`] uses the last token in order to do the classification, as other causal models
1443
+ (e.g. GPT-1) do.
1444
+
1445
+ Since it does classification on the last token, it requires to know the position of the last token. If a
1446
+ `pad_token_id` is defined in the configuration, it finds the last token that is not a padding token in each row. If
1447
+ no `pad_token_id` is defined, it simply takes the last value in each row of the batch. Since it cannot guess the
1448
+ padding tokens when `inputs_embeds` are passed instead of `input_ids`, it does the same (take the last value in
1449
+ each row of the batch).
1450
+ """,
1451
+ GPT2_START_DOCSTRING,
1452
+ )
1453
+ class SVD_GPT2ForSequenceClassification(SVD_GPT2PreTrainedModel):
1454
+ def __init__(self, config):
1455
+ super().__init__(config)
1456
+ self.num_labels = config.num_labels
1457
+ self.transformer = SVD_GPT2Model(config)
1458
+ self.score = nn.Linear(config.n_embd, self.num_labels, bias=False)
1459
+
1460
+ # Model parallel
1461
+ self.model_parallel = False
1462
+ self.device_map = None
1463
+
1464
+ # Initialize weights and apply final processing
1465
+ self.post_init()
1466
+
1467
+ @add_start_docstrings_to_model_forward(GPT2_INPUTS_DOCSTRING)
1468
+ @add_code_sample_docstrings(
1469
+ checkpoint="microsoft/DialogRPT-updown",
1470
+ output_type=SequenceClassifierOutputWithPast,
1471
+ config_class=_CONFIG_FOR_DOC,
1472
+ )
1473
+ def forward(
1474
+ self,
1475
+ input_ids: Optional[torch.LongTensor] = None,
1476
+ past_key_values: Optional[Tuple[Tuple[torch.Tensor]]] = None,
1477
+ attention_mask: Optional[torch.FloatTensor] = None,
1478
+ token_type_ids: Optional[torch.LongTensor] = None,
1479
+ position_ids: Optional[torch.LongTensor] = None,
1480
+ head_mask: Optional[torch.FloatTensor] = None,
1481
+ inputs_embeds: Optional[torch.FloatTensor] = None,
1482
+ labels: Optional[torch.LongTensor] = None,
1483
+ use_cache: Optional[bool] = None,
1484
+ output_attentions: Optional[bool] = None,
1485
+ output_hidden_states: Optional[bool] = None,
1486
+ return_dict: Optional[bool] = None,
1487
+ ) -> Union[Tuple, SequenceClassifierOutputWithPast]:
1488
+ r"""
1489
+ labels (`torch.LongTensor` of shape `(batch_size,)`, *optional*):
1490
+ Labels for computing the sequence classification/regression loss. Indices should be in `[0, ...,
1491
+ config.num_labels - 1]`. If `config.num_labels == 1` a regression loss is computed (Mean-Square loss), If
1492
+ `config.num_labels > 1` a classification loss is computed (Cross-Entropy).
1493
+ """
1494
+ return_dict = return_dict if return_dict is not None else self.config.use_return_dict
1495
+
1496
+ transformer_outputs = self.transformer(
1497
+ input_ids,
1498
+ past_key_values=past_key_values,
1499
+ attention_mask=attention_mask,
1500
+ token_type_ids=token_type_ids,
1501
+ position_ids=position_ids,
1502
+ head_mask=head_mask,
1503
+ inputs_embeds=inputs_embeds,
1504
+ use_cache=use_cache,
1505
+ output_attentions=output_attentions,
1506
+ output_hidden_states=output_hidden_states,
1507
+ return_dict=return_dict,
1508
+ )
1509
+ hidden_states = transformer_outputs[0]
1510
+ logits = self.score(hidden_states)
1511
+
1512
+ if input_ids is not None:
1513
+ batch_size, sequence_length = input_ids.shape[:2]
1514
+ else:
1515
+ batch_size, sequence_length = inputs_embeds.shape[:2]
1516
+
1517
+ assert (
1518
+ self.config.pad_token_id is not None or batch_size == 1
1519
+ ), "Cannot handle batch sizes > 1 if no padding token is defined."
1520
+ if self.config.pad_token_id is None:
1521
+ sequence_lengths = -1
1522
+ else:
1523
+ if input_ids is not None:
1524
+ # if no pad token found, use modulo instead of reverse indexing for ONNX compatibility
1525
+ sequence_lengths = torch.eq(input_ids, self.config.pad_token_id).int().argmax(-1) - 1
1526
+ sequence_lengths = sequence_lengths % input_ids.shape[-1]
1527
+ sequence_lengths = sequence_lengths.to(logits.device)
1528
+ else:
1529
+ sequence_lengths = -1
1530
+ logger.warning_once(
1531
+ f"{self.__class__.__name__} will not detect padding tokens in `inputs_embeds`. Results may be "
1532
+ "unexpected if using padding tokens in conjunction with `inputs_embeds.`"
1533
+ )
1534
+
1535
+ pooled_logits = logits[torch.arange(batch_size, device=logits.device), sequence_lengths]
1536
+
1537
+ loss = None
1538
+ if labels is not None:
1539
+ if self.config.problem_type is None:
1540
+ if self.num_labels == 1:
1541
+ self.config.problem_type = "regression"
1542
+ elif self.num_labels > 1 and (labels.dtype == torch.long or labels.dtype == torch.int):
1543
+ self.config.problem_type = "single_label_classification"
1544
+ else:
1545
+ self.config.problem_type = "multi_label_classification"
1546
+
1547
+ if self.config.problem_type == "regression":
1548
+ loss_fct = MSELoss()
1549
+ if self.num_labels == 1:
1550
+ loss = loss_fct(pooled_logits.squeeze(), labels.squeeze())
1551
+ else:
1552
+ loss = loss_fct(pooled_logits, labels)
1553
+ elif self.config.problem_type == "single_label_classification":
1554
+ loss_fct = CrossEntropyLoss()
1555
+ loss = loss_fct(pooled_logits.view(-1, self.num_labels), labels.view(-1))
1556
+ elif self.config.problem_type == "multi_label_classification":
1557
+ loss_fct = BCEWithLogitsLoss()
1558
+ loss = loss_fct(pooled_logits, labels)
1559
+ if not return_dict:
1560
+ output = (pooled_logits,) + transformer_outputs[1:]
1561
+ return ((loss,) + output) if loss is not None else output
1562
+
1563
+ return SequenceClassifierOutputWithPast(
1564
+ loss=loss,
1565
+ logits=pooled_logits,
1566
+ past_key_values=transformer_outputs.past_key_values,
1567
+ hidden_states=transformer_outputs.hidden_states,
1568
+ attentions=transformer_outputs.attentions,
1569
+ )
1570
+
1571
+
1572
+ @add_start_docstrings(
1573
+ """
1574
+ GPT2 Model with a token classification head on top (a linear layer on top of the hidden-states output) e.g. for
1575
+ Named-Entity-Recognition (NER) tasks.
1576
+ """,
1577
+ GPT2_START_DOCSTRING,
1578
+ )
1579
+ class SVD_GPT2ForTokenClassification(SVD_GPT2PreTrainedModel):
1580
+ def __init__(self, config):
1581
+ super().__init__(config)
1582
+ self.num_labels = config.num_labels
1583
+
1584
+ self.transformer = SVD_GPT2Model(config)
1585
+ if hasattr(config, "classifier_dropout") and config.classifier_dropout is not None:
1586
+ classifier_dropout = config.classifier_dropout
1587
+ elif hasattr(config, "hidden_dropout") and config.hidden_dropout is not None:
1588
+ classifier_dropout = config.hidden_dropout
1589
+ else:
1590
+ classifier_dropout = 0.1
1591
+ self.dropout = nn.Dropout(classifier_dropout)
1592
+ self.classifier = nn.Linear(config.hidden_size, config.num_labels)
1593
+
1594
+ # Model parallel
1595
+ self.model_parallel = False
1596
+ self.device_map = None
1597
+
1598
+ # Initialize weights and apply final processing
1599
+ self.post_init()
1600
+
1601
+ @add_start_docstrings_to_model_forward(GPT2_INPUTS_DOCSTRING)
1602
+ # fmt: off
1603
+ @add_code_sample_docstrings(
1604
+ checkpoint="brad1141/gpt2-finetuned-comp2",
1605
+ output_type=TokenClassifierOutput,
1606
+ config_class=_CONFIG_FOR_DOC,
1607
+ expected_loss=0.25,
1608
+ expected_output=[
1609
+ "Lead",
1610
+ "Lead",
1611
+ "Lead",
1612
+ "Position",
1613
+ "Lead",
1614
+ "Lead",
1615
+ "Lead",
1616
+ "Lead",
1617
+ "Lead",
1618
+ "Lead",
1619
+ "Lead",
1620
+ "Lead",
1621
+ ],
1622
+ )
1623
+ # fmt: on
1624
+ def forward(
1625
+ self,
1626
+ input_ids: Optional[torch.LongTensor] = None,
1627
+ past_key_values: Optional[Tuple[Tuple[torch.Tensor]]] = None,
1628
+ attention_mask: Optional[torch.FloatTensor] = None,
1629
+ token_type_ids: Optional[torch.LongTensor] = None,
1630
+ position_ids: Optional[torch.LongTensor] = None,
1631
+ head_mask: Optional[torch.FloatTensor] = None,
1632
+ inputs_embeds: Optional[torch.FloatTensor] = None,
1633
+ labels: Optional[torch.LongTensor] = None,
1634
+ use_cache: Optional[bool] = None,
1635
+ output_attentions: Optional[bool] = None,
1636
+ output_hidden_states: Optional[bool] = None,
1637
+ return_dict: Optional[bool] = None,
1638
+ ) -> Union[Tuple, TokenClassifierOutput]:
1639
+ r"""
1640
+ labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
1641
+ Labels for computing the sequence classification/regression loss. Indices should be in `[0, ...,
1642
+ config.num_labels - 1]`. If `config.num_labels == 1` a regression loss is computed (Mean-Square loss), If
1643
+ `config.num_labels > 1` a classification loss is computed (Cross-Entropy).
1644
+ """
1645
+ return_dict = return_dict if return_dict is not None else self.config.use_return_dict
1646
+
1647
+ transformer_outputs = self.transformer(
1648
+ input_ids,
1649
+ past_key_values=past_key_values,
1650
+ attention_mask=attention_mask,
1651
+ token_type_ids=token_type_ids,
1652
+ position_ids=position_ids,
1653
+ head_mask=head_mask,
1654
+ inputs_embeds=inputs_embeds,
1655
+ use_cache=use_cache,
1656
+ output_attentions=output_attentions,
1657
+ output_hidden_states=output_hidden_states,
1658
+ return_dict=return_dict,
1659
+ )
1660
+
1661
+ hidden_states = transformer_outputs[0]
1662
+ hidden_states = self.dropout(hidden_states)
1663
+ logits = self.classifier(hidden_states)
1664
+
1665
+ loss = None
1666
+ if labels is not None:
1667
+ labels = labels.to(logits.device)
1668
+ loss_fct = CrossEntropyLoss()
1669
+ loss = loss_fct(logits.view(-1, self.num_labels), labels.view(-1))
1670
+
1671
+ if not return_dict:
1672
+ output = (logits,) + transformer_outputs[2:]
1673
+ return ((loss,) + output) if loss is not None else output
1674
+
1675
+ return TokenClassifierOutput(
1676
+ loss=loss,
1677
+ logits=logits,
1678
+ hidden_states=transformer_outputs.hidden_states,
1679
+ attentions=transformer_outputs.attentions,
1680
+ )
1681
+
1682
+
1683
+ @add_start_docstrings(
1684
+ """
1685
+ The GPT-2 Model transformer with a span classification head on top for extractive question-answering tasks like
1686
+ SQuAD (a linear layer on top of the hidden-states output to compute `span start logits` and `span end logits`).
1687
+ """,
1688
+ GPT2_START_DOCSTRING,
1689
+ )
1690
+ class SVD_GPT2ForQuestionAnswering(SVD_GPT2PreTrainedModel):
1691
+ def __init__(self, config):
1692
+ super().__init__(config)
1693
+ self.num_labels = config.num_labels
1694
+ self.transformer = SVD_GPT2Model(config)
1695
+ self.qa_outputs = nn.Linear(config.hidden_size, 2)
1696
+
1697
+ # Model parallel
1698
+ self.model_parallel = False
1699
+ self.device_map = None
1700
+
1701
+ # Initialize weights and apply final processing
1702
+ self.post_init()
1703
+
1704
+ @add_start_docstrings_to_model_forward(GPT2_INPUTS_DOCSTRING.format("batch_size, sequence_length"))
1705
+ @add_code_sample_docstrings(
1706
+ checkpoint=_CHECKPOINT_FOR_DOC,
1707
+ output_type=QuestionAnsweringModelOutput,
1708
+ config_class=_CONFIG_FOR_DOC,
1709
+ real_checkpoint=_CHECKPOINT_FOR_DOC,
1710
+ )
1711
+ def forward(
1712
+ self,
1713
+ input_ids: Optional[torch.LongTensor] = None,
1714
+ attention_mask: Optional[torch.FloatTensor] = None,
1715
+ token_type_ids: Optional[torch.LongTensor] = None,
1716
+ position_ids: Optional[torch.LongTensor] = None,
1717
+ head_mask: Optional[torch.FloatTensor] = None,
1718
+ inputs_embeds: Optional[torch.FloatTensor] = None,
1719
+ start_positions: Optional[torch.LongTensor] = None,
1720
+ end_positions: Optional[torch.LongTensor] = None,
1721
+ output_attentions: Optional[bool] = None,
1722
+ output_hidden_states: Optional[bool] = None,
1723
+ return_dict: Optional[bool] = None,
1724
+ ) -> Union[Tuple, QuestionAnsweringModelOutput]:
1725
+ r"""
1726
+ start_positions (`torch.LongTensor` of shape `(batch_size,)`, *optional*):
1727
+ Labels for position (index) of the start of the labelled span for computing the token classification loss.
1728
+ Positions are clamped to the length of the sequence (`sequence_length`). Position outside of the sequence
1729
+ are not taken into account for computing the loss.
1730
+ end_positions (`torch.LongTensor` of shape `(batch_size,)`, *optional*):
1731
+ Labels for position (index) of the end of the labelled span for computing the token classification loss.
1732
+ Positions are clamped to the length of the sequence (`sequence_length`). Position outside of the sequence
1733
+ are not taken into account for computing the loss.
1734
+ """
1735
+ return_dict = return_dict if return_dict is not None else self.config.use_return_dict
1736
+
1737
+ outputs = self.transformer(
1738
+ input_ids,
1739
+ attention_mask=attention_mask,
1740
+ token_type_ids=token_type_ids,
1741
+ position_ids=position_ids,
1742
+ head_mask=head_mask,
1743
+ inputs_embeds=inputs_embeds,
1744
+ output_attentions=output_attentions,
1745
+ output_hidden_states=output_hidden_states,
1746
+ return_dict=return_dict,
1747
+ )
1748
+
1749
+ sequence_output = outputs[0]
1750
+
1751
+ logits = self.qa_outputs(sequence_output)
1752
+ start_logits, end_logits = logits.split(1, dim=-1)
1753
+ start_logits = start_logits.squeeze(-1).contiguous()
1754
+ end_logits = end_logits.squeeze(-1).contiguous()
1755
+
1756
+ total_loss = None
1757
+ if start_positions is not None and end_positions is not None:
1758
+ # If we are on multi-GPU, split add a dimension
1759
+ if len(start_positions.size()) > 1:
1760
+ start_positions = start_positions.squeeze(-1).to(start_logits.device)
1761
+ if len(end_positions.size()) > 1:
1762
+ end_positions = end_positions.squeeze(-1).to(end_logits.device)
1763
+ # sometimes the start/end positions are outside our model inputs, we ignore these terms
1764
+ ignored_index = start_logits.size(1)
1765
+ start_positions = start_positions.clamp(0, ignored_index)
1766
+ end_positions = end_positions.clamp(0, ignored_index)
1767
+
1768
+ loss_fct = CrossEntropyLoss(ignore_index=ignored_index)
1769
+ start_loss = loss_fct(start_logits, start_positions)
1770
+ end_loss = loss_fct(end_logits, end_positions)
1771
+ total_loss = (start_loss + end_loss) / 2
1772
+
1773
+ if not return_dict:
1774
+ output = (start_logits, end_logits) + outputs[2:]
1775
+ return ((total_loss,) + output) if total_loss is not None else output
1776
+
1777
+ return QuestionAnsweringModelOutput(
1778
+ loss=total_loss,
1779
+ start_logits=start_logits,
1780
+ end_logits=end_logits,
1781
+ hidden_states=outputs.hidden_states,
1782
+ attentions=outputs.attentions,
1783
+ )
special_tokens_map.json ADDED
@@ -0,0 +1,6 @@
 
 
 
 
 
 
 
1
+ {
2
+ "bos_token": "<|endoftext|>",
3
+ "eos_token": "<|endoftext|>",
4
+ "pad_token": "<|endoftext|>",
5
+ "unk_token": "<|endoftext|>"
6
+ }
tokenizer.json ADDED
The diff for this file is too large to render. See raw diff
 
tokenizer_config.json ADDED
@@ -0,0 +1,21 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "add_prefix_space": false,
3
+ "added_tokens_decoder": {
4
+ "50256": {
5
+ "content": "<|endoftext|>",
6
+ "lstrip": false,
7
+ "normalized": true,
8
+ "rstrip": false,
9
+ "single_word": false,
10
+ "special": true
11
+ }
12
+ },
13
+ "bos_token": "<|endoftext|>",
14
+ "clean_up_tokenization_spaces": false,
15
+ "eos_token": "<|endoftext|>",
16
+ "extra_special_tokens": {},
17
+ "model_max_length": 1024,
18
+ "pad_token": "<|endoftext|>",
19
+ "tokenizer_class": "GPT2Tokenizer",
20
+ "unk_token": "<|endoftext|>"
21
+ }
vocab.json ADDED
The diff for this file is too large to render. See raw diff