abdelrahmane01 commited on
Commit
8018928
·
verified ·
1 Parent(s): 8bc0950

Upload 2 files

Browse files
Models/25/configuration_phi3.py ADDED
@@ -0,0 +1,226 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # coding=utf-8
2
+ # Copyright 2024 Microsoft and the HuggingFace Inc. team. All rights reserved.
3
+ #
4
+ # Licensed under the Apache License, Version 2.0 (the "License");
5
+ # you may not use this file except in compliance with the License.
6
+ # You may obtain a copy of the License at
7
+ #
8
+ # http://www.apache.org/licenses/LICENSE-2.0
9
+ #
10
+ # Unless required by applicable law or agreed to in writing, software
11
+ # distributed under the License is distributed on an "AS IS" BASIS,
12
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
+ # See the License for the specific language governing permissions and
14
+ # limitations under the License.
15
+
16
+ """Phi-3 model configuration"""
17
+
18
+ from transformers.configuration_utils import PretrainedConfig
19
+ from transformers.utils import logging
20
+
21
+
22
+ logger = logging.get_logger(__name__)
23
+
24
+
25
+ class Phi3Config(PretrainedConfig):
26
+ r"""
27
+ This is the configuration class to store the configuration of a [`Phi3Model`]. It is used to instantiate a Phi-3
28
+ model according to the specified arguments, defining the model architecture. Instantiating a configuration with the
29
+ defaults will yield a similar configuration to that of the
30
+ [microsoft/Phi-3-mini-4k-instruct](https://huggingface.co/microsoft/Phi-3-mini-4k-instruct).
31
+
32
+ Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the
33
+ documentation from [`PretrainedConfig`] for more information.
34
+
35
+ Args:
36
+ vocab_size (`int`, *optional*, defaults to 32064):
37
+ Vocabulary size of the Phi-3 model. Defines the number of different tokens that can be represented by the
38
+ `inputs_ids` passed when calling [`Phi3Model`].
39
+ hidden_size (`int`, *optional*, defaults to 3072):
40
+ Dimension of the hidden representations.
41
+ intermediate_size (`int`, *optional*, defaults to 8192):
42
+ Dimension of the MLP representations.
43
+ num_hidden_layers (`int`, *optional*, defaults to 32):
44
+ Number of hidden layers in the Transformer decoder.
45
+ num_attention_heads (`int`, *optional*, defaults to 32):
46
+ Number of attention heads for each attention layer in the Transformer decoder.
47
+ num_key_value_heads (`int`, *optional*):
48
+ This is the number of key_value heads that should be used to implement Grouped Query Attention. If
49
+ `num_key_value_heads=num_attention_heads`, the model will use Multi Head Attention (MHA), if
50
+ `num_key_value_heads=1` the model will use Multi Query Attention (MQA) otherwise GQA is used. When
51
+ converting a multi-head checkpoint to a GQA checkpoint, each group key and value head should be constructed
52
+ by meanpooling all the original heads within that group. For more details checkout [this
53
+ paper](https://arxiv.org/pdf/2305.13245.pdf). If it is not specified, will default to
54
+ `num_attention_heads`.
55
+ resid_pdrop (`float`, *optional*, defaults to 0.0):
56
+ Dropout probability for mlp outputs.
57
+ embd_pdrop (`int`, *optional*, defaults to 0.0):
58
+ The dropout ratio for the embeddings.
59
+ attention_dropout (`float`, *optional*, defaults to 0.0):
60
+ The dropout ratio after computing the attention scores.
61
+ hidden_act (`str` or `function`, *optional*, defaults to `"silu"`):
62
+ The non-linear activation function (function or string) in the decoder.
63
+ max_position_embeddings (`int`, *optional*, defaults to 4096):
64
+ The maximum sequence length that this model might ever be used with.
65
+ original_max_position_embeddings (`int`, *optional*, defaults to 4096):
66
+ The maximum sequence length that this model was trained with. This is used to determine the size of the
67
+ original RoPE embeddings when using long scaling.
68
+ initializer_range (`float`, *optional*, defaults to 0.02):
69
+ The standard deviation of the truncated_normal_initializer for initializing all weight matrices.
70
+ rms_norm_eps (`float`, *optional*, defaults to 1e-05):
71
+ The epsilon value used for the RMSNorm.
72
+ use_cache (`bool`, *optional*, defaults to `True`):
73
+ Whether or not the model should return the last key/values attentions (not used by all models). Only
74
+ relevant if `config.is_decoder=True`. Whether to tie weight embeddings or not.
75
+ tie_word_embeddings (`bool`, *optional*, defaults to `False`):
76
+ Whether to tie weight embeddings
77
+ rope_theta (`float`, *optional*, defaults to 10000.0):
78
+ The base period of the RoPE embeddings.
79
+ rope_scaling (`dict`, *optional*):
80
+ The scaling strategy for the RoPE embeddings. If `None`, no scaling is applied. If a dictionary, it must
81
+ contain the following keys: `type`, `short_factor` and `long_factor`. The `type` must be `longrope` and
82
+ the `short_factor` and `long_factor` must be lists of numbers with the same length as the hidden size
83
+ divided by the number of attention heads divided by 2.
84
+ partial_rotary_factor (`float`, *optional*, defaults to 1.0):
85
+ Percentage of the query and keys which will have rotary embedding. Must be between 0.0 and 1.0.
86
+ bos_token_id (`int`, *optional*, defaults to 1):
87
+ The id of the "beginning-of-sequence" token.
88
+ eos_token_id (`int`, *optional*, defaults to 32000):
89
+ The id of the "end-of-sequence" token.
90
+ pad_token_id (`int`, *optional*, defaults to 32000):
91
+ The id of the padding token.
92
+ sliding_window (`int`, *optional*):
93
+ Sliding window attention window size. If `None`, no sliding window is applied.
94
+
95
+ Example:
96
+
97
+ ```python
98
+ >>> from transformers import Phi3Model, Phi3Config
99
+
100
+ >>> # Initializing a Phi-3 style configuration
101
+ >>> configuration = Phi3Config.from_pretrained("microsoft/Phi-3-mini-4k-instruct")
102
+
103
+ >>> # Initializing a model from the configuration
104
+ >>> model = Phi3Model(configuration)
105
+
106
+ >>> # Accessing the model configuration
107
+ >>> configuration = model.config
108
+ ```"""
109
+
110
+ model_type = "phi3"
111
+ keys_to_ignore_at_inference = ["past_key_values"]
112
+
113
+ def __init__(
114
+ self,
115
+ vocab_size=32064,
116
+ hidden_size=3072,
117
+ intermediate_size=8192,
118
+ num_hidden_layers=32,
119
+ num_attention_heads=32,
120
+ num_key_value_heads=None,
121
+ resid_pdrop=0.0,
122
+ embd_pdrop=0.0,
123
+ attention_dropout=0.0,
124
+ hidden_act="silu",
125
+ max_position_embeddings=4096,
126
+ original_max_position_embeddings=4096,
127
+ initializer_range=0.02,
128
+ rms_norm_eps=1e-5,
129
+ use_cache=True,
130
+ tie_word_embeddings=False,
131
+ rope_theta=10000.0,
132
+ rope_scaling=None,
133
+ partial_rotary_factor=1.0,
134
+ bos_token_id=1,
135
+ eos_token_id=32000,
136
+ pad_token_id=32000,
137
+ sliding_window=None,
138
+ **kwargs,
139
+ ):
140
+ self.vocab_size = vocab_size
141
+ self.hidden_size = hidden_size
142
+ self.intermediate_size = intermediate_size
143
+ self.num_hidden_layers = num_hidden_layers
144
+ self.num_attention_heads = num_attention_heads
145
+
146
+ if num_key_value_heads is None:
147
+ num_key_value_heads = num_attention_heads
148
+
149
+ self.num_key_value_heads = num_key_value_heads
150
+ self.resid_pdrop = resid_pdrop
151
+ self.embd_pdrop = embd_pdrop
152
+ self.attention_dropout = attention_dropout
153
+ self.hidden_act = hidden_act
154
+ self.max_position_embeddings = max_position_embeddings
155
+ self.original_max_position_embeddings = original_max_position_embeddings
156
+ self.initializer_range = initializer_range
157
+ self.rms_norm_eps = rms_norm_eps
158
+ self.use_cache = use_cache
159
+ self.rope_theta = rope_theta
160
+ self.rope_scaling = rope_scaling
161
+ self.partial_rotary_factor = partial_rotary_factor
162
+ self._rope_scaling_adjustment()
163
+ self._rope_scaling_validation()
164
+ self.sliding_window = sliding_window
165
+
166
+ super().__init__(
167
+ bos_token_id=bos_token_id,
168
+ eos_token_id=eos_token_id,
169
+ pad_token_id=pad_token_id,
170
+ tie_word_embeddings=tie_word_embeddings,
171
+ **kwargs,
172
+ )
173
+
174
+ def _rope_scaling_adjustment(self):
175
+ """
176
+ Adjust the `type` of the `rope_scaling` configuration for backward compatibility.
177
+ """
178
+ if self.rope_scaling is None:
179
+ return
180
+
181
+ rope_scaling_type = self.rope_scaling.get("type", None)
182
+
183
+ # For backward compatibility if previous version used "su" or "yarn"
184
+ if rope_scaling_type is not None and rope_scaling_type in ["su", "yarn"]:
185
+ self.rope_scaling["type"] = "longrope"
186
+
187
+ def _rope_scaling_validation(self):
188
+ """
189
+ Validate the `rope_scaling` configuration.
190
+ """
191
+ if self.rope_scaling is None:
192
+ return
193
+
194
+ if not isinstance(self.rope_scaling, dict) or len(self.rope_scaling) != 3:
195
+ raise ValueError(
196
+ "`rope_scaling` must be a dictionary with three fields, `type`, `short_factor` and `long_factor`, "
197
+ f"got {self.rope_scaling}"
198
+ )
199
+ rope_scaling_type = self.rope_scaling.get("type", None)
200
+ rope_scaling_short_factor = self.rope_scaling.get("short_factor", None)
201
+ rope_scaling_long_factor = self.rope_scaling.get("long_factor", None)
202
+ if rope_scaling_type is None or rope_scaling_type not in ["longrope"]:
203
+ raise ValueError(f"`rope_scaling`'s type field must be one of ['longrope'], got {rope_scaling_type}")
204
+ if not (
205
+ isinstance(rope_scaling_short_factor, list)
206
+ and all(isinstance(x, (int, float)) for x in rope_scaling_short_factor)
207
+ ):
208
+ raise ValueError(
209
+ f"`rope_scaling`'s short_factor field must be a list of numbers, got {rope_scaling_short_factor}"
210
+ )
211
+ rotary_ndims = int(self.hidden_size // self.num_attention_heads * self.partial_rotary_factor)
212
+ if not len(rope_scaling_short_factor) == rotary_ndims // 2:
213
+ raise ValueError(
214
+ f"`rope_scaling`'s short_factor field must have length {rotary_ndims // 2}, got {len(rope_scaling_short_factor)}"
215
+ )
216
+ if not (
217
+ isinstance(rope_scaling_long_factor, list)
218
+ and all(isinstance(x, (int, float)) for x in rope_scaling_long_factor)
219
+ ):
220
+ raise ValueError(
221
+ f"`rope_scaling`'s long_factor field must be a list of numbers, got {rope_scaling_long_factor}"
222
+ )
223
+ if not len(rope_scaling_long_factor) == rotary_ndims // 2:
224
+ raise ValueError(
225
+ f"`rope_scaling`'s long_factor field must have length {rotary_ndims // 2}, got {len(rope_scaling_long_factor)}"
226
+ )
Models/25/modeling_phi3.py ADDED
@@ -0,0 +1,1185 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # coding=utf-8
2
+ # Copyright 2024 Microsoft and the HuggingFace Inc. team. All rights reserved.
3
+ #
4
+ # Licensed under the Apache License, Version 2.0 (the "License");
5
+ # you may not use this file except in compliance with the License.
6
+ # You may obtain a copy of the License at
7
+ #
8
+ # http://www.apache.org/licenses/LICENSE-2.0
9
+ #
10
+ # Unless required by applicable law or agreed to in writing, software
11
+ # distributed under the License is distributed on an "AS IS" BASIS,
12
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
+ # See the License for the specific language governing permissions and
14
+ # limitations under the License.
15
+
16
+ """PyTorch Phi-3 model."""
17
+
18
+ from typing import Callable, List, Optional, Tuple, Union
19
+
20
+ import torch
21
+ from torch import nn
22
+
23
+ from transformers.activations import ACT2FN
24
+ from transformers.cache_utils import Cache, DynamicCache, SlidingWindowCache, StaticCache
25
+ from transformers.generation import GenerationMixin
26
+ from transformers.modeling_attn_mask_utils import AttentionMaskConverter
27
+ from transformers.modeling_flash_attention_utils import FlashAttentionKwargs
28
+ from transformers.modeling_outputs import (
29
+ BaseModelOutputWithPast,
30
+ CausalLMOutputWithPast,
31
+ SequenceClassifierOutputWithPast,
32
+ TokenClassifierOutput,
33
+ )
34
+ from transformers.modeling_rope_utils import ROPE_INIT_FUNCTIONS
35
+ from transformers.modeling_utils import ALL_ATTENTION_FUNCTIONS, PreTrainedModel
36
+ from transformers.processing_utils import Unpack
37
+ from transformers.utils import (
38
+ add_code_sample_docstrings,
39
+ add_start_docstrings,
40
+ add_start_docstrings_to_model_forward,
41
+ logging,
42
+ replace_return_docstrings,
43
+ )
44
+
45
+ # Robust import for LossKwargs
46
+ try:
47
+ from transformers.utils import LossKwargs
48
+ except ImportError:
49
+ from transformers.utils import TransformersKwargs as LossKwargs
50
+ from transformers.utils.deprecation import deprecate_kwarg
51
+ from .configuration_phi3 import Phi3Config
52
+
53
+
54
+ logger = logging.get_logger(__name__)
55
+
56
+ _CHECKPOINT_FOR_DOC = "microsoft/Phi-3-mini-4k-instruct"
57
+ _CONFIG_FOR_DOC = "Phi3Config"
58
+
59
+
60
+ class Phi3MLP(nn.Module):
61
+ def __init__(self, config):
62
+ super().__init__()
63
+
64
+ self.config = config
65
+ self.gate_up_proj = nn.Linear(config.hidden_size, 2 * config.intermediate_size, bias=False)
66
+ self.down_proj = nn.Linear(config.intermediate_size, config.hidden_size, bias=False)
67
+ self.activation_fn = ACT2FN[config.hidden_act]
68
+
69
+ def forward(self, hidden_states: torch.FloatTensor) -> torch.FloatTensor:
70
+ up_states = self.gate_up_proj(hidden_states)
71
+
72
+ gate, up_states = up_states.chunk(2, dim=-1)
73
+ up_states = up_states * self.activation_fn(gate)
74
+
75
+ return self.down_proj(up_states)
76
+
77
+
78
+ def rotate_half(x):
79
+ """Rotates half the hidden dims of the input."""
80
+ x1 = x[..., : x.shape[-1] // 2]
81
+ x2 = x[..., x.shape[-1] // 2 :]
82
+ return torch.cat((-x2, x1), dim=-1)
83
+
84
+
85
+ def repeat_kv(hidden_states: torch.Tensor, n_rep: int) -> torch.Tensor:
86
+ """
87
+ This is the equivalent of torch.repeat_interleave(x, dim=1, repeats=n_rep). The hidden states go from (batch,
88
+ num_key_value_heads, seqlen, head_dim) to (batch, num_attention_heads, seqlen, head_dim)
89
+ """
90
+ batch, num_key_value_heads, slen, head_dim = hidden_states.shape
91
+ if n_rep == 1:
92
+ return hidden_states
93
+ hidden_states = hidden_states[:, :, None, :, :].expand(batch, num_key_value_heads, n_rep, slen, head_dim)
94
+ return hidden_states.reshape(batch, num_key_value_heads * n_rep, slen, head_dim)
95
+
96
+
97
+ def eager_attention_forward(
98
+ module: nn.Module,
99
+ query: torch.Tensor,
100
+ key: torch.Tensor,
101
+ value: torch.Tensor,
102
+ attention_mask: Optional[torch.Tensor],
103
+ scaling: float,
104
+ dropout: float = 0.0,
105
+ **kwargs,
106
+ ):
107
+ key_states = repeat_kv(key, module.num_key_value_groups)
108
+ value_states = repeat_kv(value, module.num_key_value_groups)
109
+
110
+ attn_weights = torch.matmul(query, key_states.transpose(2, 3)) * scaling
111
+ if attention_mask is not None:
112
+ causal_mask = attention_mask[:, :, :, : key_states.shape[-2]]
113
+ attn_weights = attn_weights + causal_mask
114
+
115
+ attn_weights = nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to(query.dtype)
116
+ attn_weights = nn.functional.dropout(attn_weights, p=dropout, training=module.training)
117
+ attn_output = torch.matmul(attn_weights, value_states)
118
+ attn_output = attn_output.transpose(1, 2).contiguous()
119
+
120
+ return attn_output, attn_weights
121
+
122
+
123
+ def apply_rotary_pos_emb(q, k, cos, sin, position_ids=None, unsqueeze_dim=1):
124
+ """Applies Rotary Position Embedding to the query and key tensors.
125
+
126
+ Args:
127
+ q (`torch.Tensor`): The query tensor.
128
+ k (`torch.Tensor`): The key tensor.
129
+ cos (`torch.Tensor`): The cosine part of the rotary embedding.
130
+ sin (`torch.Tensor`): The sine part of the rotary embedding.
131
+ position_ids (`torch.Tensor`, *optional*):
132
+ Deprecated and unused.
133
+ unsqueeze_dim (`int`, *optional*, defaults to 1):
134
+ The 'unsqueeze_dim' argument specifies the dimension along which to unsqueeze cos[position_ids] and
135
+ sin[position_ids] so that they can be properly broadcasted to the dimensions of q and k. For example, note
136
+ that cos[position_ids] and sin[position_ids] have the shape [batch_size, seq_len, head_dim]. Then, if q and
137
+ k have the shape [batch_size, heads, seq_len, head_dim], then setting unsqueeze_dim=1 makes
138
+ cos[position_ids] and sin[position_ids] broadcastable to the shapes of q and k. Similarly, if q and k have
139
+ the shape [batch_size, seq_len, heads, head_dim], then set unsqueeze_dim=2.
140
+ Returns:
141
+ `tuple(torch.Tensor)` comprising of the query and key tensors rotated using the Rotary Position Embedding.
142
+ """
143
+ cos = cos.unsqueeze(unsqueeze_dim)
144
+ sin = sin.unsqueeze(unsqueeze_dim)
145
+
146
+ rotary_dim = cos.shape[-1]
147
+ q_rot, q_pass = q[..., :rotary_dim], q[..., rotary_dim:]
148
+ k_rot, k_pass = k[..., :rotary_dim], k[..., rotary_dim:]
149
+
150
+ q_embed = torch.cat([(q_rot * cos) + (rotate_half(q_rot) * sin), q_pass], dim=-1)
151
+ k_embed = torch.cat([(k_rot * cos) + (rotate_half(k_rot) * sin), k_pass], dim=-1)
152
+ return q_embed, k_embed
153
+
154
+
155
+ class Phi3Attention(nn.Module):
156
+ """Multi-headed attention from 'Attention Is All You Need' paper"""
157
+
158
+ def __init__(self, config: Phi3Config, layer_idx: Optional[int] = None):
159
+ super().__init__()
160
+ self.config = config
161
+ self.layer_idx = layer_idx
162
+ self.head_dim = getattr(config, "head_dim", config.hidden_size // config.num_attention_heads)
163
+ self.num_key_value_groups = config.num_attention_heads // config.num_key_value_heads
164
+ self.num_key_value_heads = config.num_key_value_heads
165
+ self.scaling = self.head_dim**-0.5
166
+ self.attention_dropout = config.attention_dropout
167
+ self.is_causal = True
168
+
169
+ op_size = config.num_attention_heads * self.head_dim + 2 * (config.num_key_value_heads * self.head_dim)
170
+ self.o_proj = nn.Linear(config.num_attention_heads * self.head_dim, config.hidden_size, bias=False)
171
+ self.qkv_proj = nn.Linear(config.hidden_size, op_size, bias=False)
172
+
173
+ def forward(
174
+ self,
175
+ hidden_states: torch.Tensor,
176
+ position_embeddings: Tuple[torch.Tensor, torch.Tensor],
177
+ attention_mask: Optional[torch.Tensor],
178
+ past_key_value: Optional[Cache] = None,
179
+ cache_position: Optional[torch.LongTensor] = None,
180
+ **kwargs: Unpack[FlashAttentionKwargs],
181
+ ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
182
+ input_shape = hidden_states.shape[:-1]
183
+ hidden_shape = (*input_shape, -1, self.head_dim)
184
+
185
+ qkv = self.qkv_proj(hidden_states)
186
+ query_pos = self.config.num_attention_heads * self.head_dim
187
+ query_states = qkv[..., :query_pos]
188
+ key_states = qkv[..., query_pos : query_pos + self.num_key_value_heads * self.head_dim]
189
+ value_states = qkv[..., query_pos + self.num_key_value_heads * self.head_dim :]
190
+
191
+ query_states = query_states.view(hidden_shape).transpose(1, 2)
192
+ key_states = key_states.view(hidden_shape).transpose(1, 2)
193
+ value_states = value_states.view(hidden_shape).transpose(1, 2)
194
+
195
+ cos, sin = position_embeddings
196
+ query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin)
197
+
198
+ if past_key_value is not None:
199
+ # sin and cos are specific to RoPE models; cache_position needed for the static cache
200
+ cache_kwargs = {"sin": sin, "cos": cos, "cache_position": cache_position}
201
+ key_states, value_states = past_key_value.update(key_states, value_states, self.layer_idx, cache_kwargs)
202
+
203
+ attention_interface: Callable = eager_attention_forward
204
+ if self.config._attn_implementation != "eager":
205
+ if self.config._attn_implementation == "sdpa" and kwargs.get("output_attentions", False):
206
+ logger.warning_once(
207
+ "`torch.nn.functional.scaled_dot_product_attention` does not support `output_attentions=True`. Falling back to "
208
+ 'eager attention. This warning can be removed using the argument `attn_implementation="eager"` when loading the model.'
209
+ )
210
+ else:
211
+ attention_interface = ALL_ATTENTION_FUNCTIONS[self.config._attn_implementation]
212
+
213
+ attn_output, attn_weights = attention_interface(
214
+ self,
215
+ query_states,
216
+ key_states,
217
+ value_states,
218
+ attention_mask,
219
+ dropout=0.0 if not self.training else self.attention_dropout,
220
+ scaling=self.scaling,
221
+ sliding_window=getattr(self.config, "sliding_window", None),
222
+ **kwargs,
223
+ )
224
+
225
+ attn_output = attn_output.reshape(*input_shape, -1).contiguous()
226
+ attn_output = self.o_proj(attn_output)
227
+ return attn_output, attn_weights
228
+
229
+
230
+ class Phi3RMSNorm(nn.Module):
231
+ def __init__(self, hidden_size, eps=1e-6):
232
+ """
233
+ Phi3RMSNorm is equivalent to T5LayerNorm
234
+ """
235
+ super().__init__()
236
+ self.weight = nn.Parameter(torch.ones(hidden_size))
237
+ self.variance_epsilon = eps
238
+
239
+ def forward(self, hidden_states):
240
+ input_dtype = hidden_states.dtype
241
+ hidden_states = hidden_states.to(torch.float32)
242
+ variance = hidden_states.pow(2).mean(-1, keepdim=True)
243
+ hidden_states = hidden_states * torch.rsqrt(variance + self.variance_epsilon)
244
+ return self.weight * hidden_states.to(input_dtype)
245
+
246
+ def extra_repr(self):
247
+ return f"{tuple(self.weight.shape)}, eps={self.variance_epsilon}"
248
+
249
+
250
+ class Phi3DecoderLayer(nn.Module):
251
+ def __init__(self, config: Phi3Config, layer_idx: int):
252
+ super().__init__()
253
+ self.hidden_size = config.hidden_size
254
+ self.self_attn = Phi3Attention(config=config, layer_idx=layer_idx)
255
+ self.mlp = Phi3MLP(config)
256
+ self.input_layernorm = Phi3RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
257
+ self.post_attention_layernorm = Phi3RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
258
+ self.config = config
259
+ self.resid_attn_dropout = nn.Dropout(config.resid_pdrop)
260
+ self.resid_mlp_dropout = nn.Dropout(config.resid_pdrop)
261
+
262
+ def forward(
263
+ self,
264
+ hidden_states: torch.Tensor,
265
+ attention_mask: Optional[torch.Tensor] = None,
266
+ position_ids: Optional[torch.LongTensor] = None,
267
+ past_key_value: Optional[Cache] = None,
268
+ output_attentions: Optional[bool] = False,
269
+ use_cache: Optional[bool] = False,
270
+ cache_position: Optional[torch.LongTensor] = None,
271
+ position_embeddings: Optional[Tuple[torch.Tensor, torch.Tensor]] = None, # necessary, but kept here for BC
272
+ **kwargs: Unpack[FlashAttentionKwargs],
273
+ ) -> Tuple[torch.FloatTensor, Optional[Tuple[torch.FloatTensor, torch.FloatTensor]]]:
274
+ """
275
+ Args:
276
+ hidden_states (`torch.FloatTensor`):
277
+ input to the layer of shape `(batch, seq_len, embed_dim)`
278
+ attention_mask (`torch.FloatTensor`, *optional*): attention mask of size
279
+ `(batch, 1, tgt_len, src_len)` where padding elements are indicated by very large negative values.
280
+ position_ids (`torch.LongTensor` of shape `({0})`, *optional*):
281
+ Indices of positions of each input sequence tokens in the position embeddings. Selected in the range
282
+ `[0, config.n_positions - 1]`. [What are position IDs?](../glossary#position-ids)
283
+ past_key_value (`Cache`, *optional*): cached past key and value projection states
284
+ output_attentions (`bool`, *optional*):
285
+ Whether or not to return the attentions tensors of all attention layers. See `attentions` under
286
+ returned tensors for more detail.
287
+ use_cache (`bool`, *optional*):
288
+ If set to `True`, `past_key_values` key value states are returned and can be used to speed up decoding
289
+ (see `past_key_values`).
290
+ cache_position (`torch.LongTensor` of shape `(sequence_length)`, *optional*):
291
+ Indices depicting the position of the input sequence tokens in the sequence
292
+ kwargs (`dict`, *optional*):
293
+ Arbitrary kwargs to be ignored, used for FSDP and other methods that injects code
294
+ into the model
295
+ """
296
+ residual = hidden_states
297
+
298
+ hidden_states = self.input_layernorm(hidden_states)
299
+
300
+ # Self Attention
301
+ hidden_states, self_attn_weights = self.self_attn(
302
+ hidden_states=hidden_states,
303
+ attention_mask=attention_mask,
304
+ position_ids=position_ids,
305
+ past_key_value=past_key_value,
306
+ output_attentions=output_attentions,
307
+ use_cache=use_cache,
308
+ cache_position=cache_position,
309
+ position_embeddings=position_embeddings,
310
+ **kwargs,
311
+ )
312
+ hidden_states = residual + self.resid_attn_dropout(hidden_states) # main diff with Llama
313
+
314
+ residual = hidden_states
315
+ hidden_states = self.post_attention_layernorm(hidden_states)
316
+ hidden_states = self.mlp(hidden_states)
317
+ hidden_states = residual + self.resid_mlp_dropout(hidden_states) # main diff with Llama
318
+
319
+ outputs = (hidden_states,)
320
+ if output_attentions:
321
+ outputs += (self_attn_weights,)
322
+
323
+ return outputs
324
+
325
+
326
+ class Phi3RotaryEmbedding(nn.Module):
327
+ def __init__(self, config: Phi3Config, device=None):
328
+ super().__init__()
329
+ # BC: "rope_type" was originally "type"
330
+ if hasattr(config, "rope_scaling") and config.rope_scaling is not None:
331
+ self.rope_type = config.rope_scaling.get("rope_type", config.rope_scaling.get("type"))
332
+ else:
333
+ self.rope_type = "default"
334
+ self.max_seq_len_cached = config.max_position_embeddings
335
+ self.original_max_seq_len = config.max_position_embeddings
336
+
337
+ self.config = config
338
+ self.rope_init_fn = ROPE_INIT_FUNCTIONS[self.rope_type]
339
+
340
+ inv_freq, self.attention_scaling = self.rope_init_fn(self.config, device)
341
+ self.register_buffer("inv_freq", inv_freq, persistent=False)
342
+ self.original_inv_freq = self.inv_freq
343
+
344
+ def _dynamic_frequency_update(self, position_ids, device):
345
+ """
346
+ dynamic RoPE layers should recompute `inv_freq` in the following situations:
347
+ 1 - growing beyond the cached sequence length (allow scaling)
348
+ 2 - the current sequence length is in the original scale (avoid losing precision with small sequences)
349
+ """
350
+ seq_len = torch.max(position_ids) + 1
351
+ if seq_len > self.max_seq_len_cached: # growth
352
+ inv_freq, self.attention_scaling = self.rope_init_fn(self.config, device, seq_len=seq_len)
353
+ self.register_buffer("inv_freq", inv_freq, persistent=False) # TODO joao: may break with compilation
354
+ self.max_seq_len_cached = seq_len
355
+
356
+ if seq_len < self.original_max_seq_len and self.max_seq_len_cached > self.original_max_seq_len: # reset
357
+ # This .to() is needed if the model has been moved to a device after being initialized (because
358
+ # the buffer is automatically moved, but not the original copy)
359
+ self.original_inv_freq = self.original_inv_freq.to(device)
360
+ self.register_buffer("inv_freq", self.original_inv_freq, persistent=False)
361
+ self.max_seq_len_cached = self.original_max_seq_len
362
+
363
+ @torch.no_grad()
364
+ def forward(self, x, position_ids):
365
+ if "dynamic" in self.rope_type:
366
+ self._dynamic_frequency_update(position_ids, device=x.device)
367
+ elif self.rope_type == "longrope":
368
+ self._longrope_frequency_update(position_ids, device=x.device)
369
+
370
+ # Core RoPE block
371
+ inv_freq_expanded = self.inv_freq[None, :, None].float().expand(position_ids.shape[0], -1, 1)
372
+ position_ids_expanded = position_ids[:, None, :].float()
373
+ # Force float32 (see https://github.com/huggingface/transformers/pull/29285)
374
+ device_type = x.device.type
375
+ device_type = device_type if isinstance(device_type, str) and device_type != "mps" else "cpu"
376
+ with torch.autocast(device_type=device_type, enabled=False):
377
+ freqs = (inv_freq_expanded.float() @ position_ids_expanded.float()).transpose(1, 2)
378
+ emb = torch.cat((freqs, freqs), dim=-1)
379
+ cos = emb.cos()
380
+ sin = emb.sin()
381
+
382
+ # Advanced RoPE types (e.g. yarn) apply a post-processing scaling factor, equivalent to scaling attention
383
+ cos = cos * self.attention_scaling
384
+ sin = sin * self.attention_scaling
385
+
386
+ return cos.to(dtype=x.dtype), sin.to(dtype=x.dtype)
387
+
388
+ def _longrope_frequency_update(self, position_ids, device):
389
+ """Longrope uses long factor if sequence is larger than original pretraining length, short otherwise."""
390
+ seq_len = torch.max(position_ids) + 1
391
+ if hasattr(self.config, "original_max_position_embeddings"):
392
+ original_max_position_embeddings = self.config.original_max_position_embeddings
393
+ else:
394
+ original_max_position_embeddings = self.config.max_position_embeddings
395
+ if seq_len > original_max_position_embeddings:
396
+ if not hasattr(self, "long_inv_freq"):
397
+ self.long_inv_freq, _ = self.rope_init_fn(
398
+ self.config, device, seq_len=original_max_position_embeddings + 1
399
+ )
400
+ self.register_buffer("inv_freq", self.long_inv_freq, persistent=False)
401
+ else:
402
+ # This .to() is needed if the model has been moved to a device after being initialized (because
403
+ # the buffer is automatically moved, but not the original copy)
404
+ self.original_inv_freq = self.original_inv_freq.to(device)
405
+ self.register_buffer("inv_freq", self.original_inv_freq, persistent=False)
406
+
407
+
408
+ PHI3_START_DOCSTRING = r"""
409
+ This model inherits from [`PreTrainedModel`]. Check the superclass documentation for the generic methods the
410
+ library implements for all its model (such as downloading or saving, resizing the input embeddings, pruning heads
411
+ etc.)
412
+
413
+ This model is also a PyTorch [torch.nn.Module](https://pytorch.org/docs/stable/nn.html#torch.nn.Module) subclass.
414
+ Use it as a regular PyTorch Module and refer to the PyTorch documentation for all matter related to general usage
415
+ and behavior.
416
+
417
+ Parameters:
418
+ config ([`Phi3Config`]):
419
+ Model configuration class with all the parameters of the model. Initializing with a config file does not
420
+ load the weights associated with the model, only the configuration. Check out the
421
+ [`~PreTrainedModel.from_pretrained`] method to load the model weights.
422
+ """
423
+
424
+
425
+ @add_start_docstrings(
426
+ "The bare Phi3 Model outputting raw hidden-states without any specific head on top.",
427
+ PHI3_START_DOCSTRING,
428
+ )
429
+ class Phi3PreTrainedModel(PreTrainedModel):
430
+ config_class = Phi3Config
431
+ base_model_prefix = "model"
432
+ supports_gradient_checkpointing = True
433
+ _no_split_modules = ["Phi3DecoderLayer"]
434
+ _skip_keys_device_placement = ["past_key_values"]
435
+ _supports_flash_attn_2 = True
436
+ _supports_sdpa = True
437
+ _supports_flex_attn = True
438
+ _supports_cache_class = True
439
+ _supports_quantized_cache = True
440
+ _supports_static_cache = True
441
+ _supports_attention_backend = True
442
+ _version = "0.0.5"
443
+
444
+ def _init_weights(self, module):
445
+ std = self.config.initializer_range
446
+ if isinstance(module, nn.Linear):
447
+ module.weight.data.normal_(mean=0.0, std=std)
448
+ if module.bias is not None:
449
+ module.bias.data.zero_()
450
+ elif isinstance(module, nn.Embedding):
451
+ module.weight.data.normal_(mean=0.0, std=std)
452
+ if module.padding_idx is not None:
453
+ module.weight.data[module.padding_idx].zero_()
454
+
455
+
456
+ PHI3_INPUTS_DOCSTRING = r"""
457
+ Args:
458
+ input_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`):
459
+ Indices of input sequence tokens in the vocabulary. Padding will be ignored by default should you provide
460
+ it.
461
+
462
+ Indices can be obtained using [`AutoTokenizer`]. See [`PreTrainedTokenizer.encode`] and
463
+ [`PreTrainedTokenizer.__call__`] for details.
464
+
465
+ [What are input IDs?](../glossary#input-ids)
466
+ attention_mask (`torch.Tensor` of shape `(batch_size, sequence_length)`, *optional*):
467
+ Mask to avoid performing attention on padding token indices. Mask values selected in `[0, 1]`:
468
+
469
+ - 1 for tokens that are **not masked**,
470
+ - 0 for tokens that are **masked**.
471
+
472
+ [What are attention masks?](../glossary#attention-mask)
473
+
474
+ Indices can be obtained using [`AutoTokenizer`]. See [`PreTrainedTokenizer.encode`] and
475
+ [`PreTrainedTokenizer.__call__`] for details.
476
+
477
+ If `past_key_values` is used, optionally only the last `input_ids` have to be input (see
478
+ `past_key_values`).
479
+
480
+ If you want to change padding behavior, you should read [`modeling_opt._prepare_decoder_attention_mask`]
481
+ and modify to your needs. See diagram 1 in [the paper](https://arxiv.org/abs/1910.13461) for more
482
+ information on the default strategy.
483
+
484
+ - 1 indicates the head is **not masked**,
485
+ - 0 indicates the head is **masked**.
486
+ position_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
487
+ Indices of positions of each input sequence tokens in the position embeddings. Selected in the range `[0,
488
+ config.n_positions - 1]`.
489
+
490
+ [What are position IDs?](../glossary#position-ids)
491
+ past_key_values (`Cache` or `tuple(tuple(torch.FloatTensor))`, *optional*):
492
+ Pre-computed hidden-states (key and values in the self-attention blocks and in the cross-attention
493
+ blocks) that can be used to speed up sequential decoding. This typically consists in the `past_key_values`
494
+ returned by the model at a previous stage of decoding, when `use_cache=True` or `config.use_cache=True`.
495
+
496
+ Two formats are allowed:
497
+ - a [`~cache_utils.Cache`] instance, see our
498
+ [kv cache guide](https://huggingface.co/docs/transformers/en/kv_cache);
499
+ - Tuple of `tuple(torch.FloatTensor)` of length `config.n_layers`, with each tuple having 2 tensors of
500
+ shape `(batch_size, num_heads, sequence_length, embed_size_per_head)`). This is also known as the legacy
501
+ cache format.
502
+
503
+ The model will output the same cache format that is fed as input. If no `past_key_values` are passed, the
504
+ legacy cache format will be returned.
505
+
506
+ If `past_key_values` are used, the user can optionally input only the last `input_ids` (those that don't
507
+ have their past key value states given to this model) of shape `(batch_size, 1)` instead of all `input_ids`
508
+ of shape `(batch_size, sequence_length)`.
509
+ inputs_embeds (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`, *optional*):
510
+ Optionally, instead of passing `input_ids` you can choose to directly pass an embedded representation. This
511
+ is useful if you want more control over how to convert `input_ids` indices into associated vectors than the
512
+ model's internal embedding lookup matrix.
513
+ use_cache (`bool`, *optional*):
514
+ If set to `True`, `past_key_values` key value states are returned and can be used to speed up decoding (see
515
+ `past_key_values`).
516
+ output_attentions (`bool`, *optional*):
517
+ Whether or not to return the attentions tensors of all attention layers. See `attentions` under returned
518
+ tensors for more detail.
519
+ output_hidden_states (`bool`, *optional*):
520
+ Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors for
521
+ more detail.
522
+ return_dict (`bool`, *optional*):
523
+ Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple.
524
+ cache_position (`torch.LongTensor` of shape `(sequence_length)`, *optional*):
525
+ Indices depicting the position of the input sequence tokens in the sequence. Contrarily to `position_ids`,
526
+ this tensor is not affected by padding. It is used to update the cache in the correct position and to infer
527
+ the complete sequence length.
528
+ """
529
+
530
+
531
+ @add_start_docstrings(
532
+ "The bare Phi3 Model outputting raw hidden-states without any specific head on top.",
533
+ PHI3_START_DOCSTRING,
534
+ )
535
+ class Phi3Model(Phi3PreTrainedModel):
536
+ """
537
+ Transformer decoder consisting of *config.num_hidden_layers* layers. Each layer is a [`Phi3DecoderLayer`]
538
+
539
+ Args:
540
+ config: Phi3Config
541
+ """
542
+
543
+ def __init__(self, config: Phi3Config):
544
+ super().__init__(config)
545
+ self.padding_idx = config.pad_token_id
546
+ self.vocab_size = config.vocab_size
547
+
548
+ self.embed_tokens = nn.Embedding(config.vocab_size, config.hidden_size, self.padding_idx)
549
+ self.layers = nn.ModuleList(
550
+ [Phi3DecoderLayer(config, layer_idx) for layer_idx in range(config.num_hidden_layers)]
551
+ )
552
+ self.norm = Phi3RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
553
+ self.rotary_emb = Phi3RotaryEmbedding(config=config)
554
+ self.gradient_checkpointing = False
555
+
556
+ # Initialize weights and apply final processing
557
+ self.post_init()
558
+
559
+ def get_input_embeddings(self):
560
+ return self.embed_tokens
561
+
562
+ def set_input_embeddings(self, value):
563
+ self.embed_tokens = value
564
+
565
+ @add_start_docstrings_to_model_forward(PHI3_INPUTS_DOCSTRING)
566
+ def forward(
567
+ self,
568
+ input_ids: torch.LongTensor = None,
569
+ attention_mask: Optional[torch.Tensor] = None,
570
+ position_ids: Optional[torch.LongTensor] = None,
571
+ past_key_values: Optional[Cache] = None,
572
+ inputs_embeds: Optional[torch.FloatTensor] = None,
573
+ use_cache: Optional[bool] = None,
574
+ output_attentions: Optional[bool] = None,
575
+ output_hidden_states: Optional[bool] = None,
576
+ return_dict: Optional[bool] = None,
577
+ cache_position: Optional[torch.LongTensor] = None,
578
+ **flash_attn_kwargs: Unpack[FlashAttentionKwargs],
579
+ ) -> Union[Tuple, BaseModelOutputWithPast]:
580
+ output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
581
+ output_hidden_states = (
582
+ output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
583
+ )
584
+ use_cache = use_cache if use_cache is not None else self.config.use_cache
585
+ return_dict = return_dict if return_dict is not None else self.config.use_return_dict
586
+
587
+ if (input_ids is None) ^ (inputs_embeds is not None):
588
+ raise ValueError("You must specify exactly one of input_ids or inputs_embeds")
589
+
590
+ if self.gradient_checkpointing and self.training and use_cache:
591
+ logger.warning_once(
592
+ "`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`."
593
+ )
594
+ use_cache = False
595
+
596
+ if inputs_embeds is None:
597
+ inputs_embeds = self.embed_tokens(input_ids)
598
+
599
+ if use_cache and past_key_values is None:
600
+ past_key_values = DynamicCache()
601
+
602
+ if cache_position is None:
603
+ past_seen_tokens = past_key_values.get_seq_length() if past_key_values is not None else 0
604
+ cache_position = torch.arange(
605
+ past_seen_tokens, past_seen_tokens + inputs_embeds.shape[1], device=inputs_embeds.device
606
+ )
607
+
608
+ if position_ids is None:
609
+ position_ids = cache_position.unsqueeze(0)
610
+
611
+ causal_mask = self._update_causal_mask(
612
+ attention_mask, inputs_embeds, cache_position, past_key_values, output_attentions
613
+ )
614
+
615
+ hidden_states = inputs_embeds
616
+
617
+ # create position embeddings to be shared across the decoder layers
618
+ position_embeddings = self.rotary_emb(hidden_states, position_ids)
619
+
620
+ # decoder layers
621
+ all_hidden_states = () if output_hidden_states else None
622
+ all_self_attns = () if output_attentions else None
623
+
624
+ for decoder_layer in self.layers[: self.config.num_hidden_layers]:
625
+ if output_hidden_states:
626
+ all_hidden_states += (hidden_states,)
627
+
628
+ if self.gradient_checkpointing and self.training:
629
+ layer_outputs = self._gradient_checkpointing_func(
630
+ decoder_layer.__call__,
631
+ hidden_states,
632
+ causal_mask,
633
+ position_ids,
634
+ past_key_values,
635
+ output_attentions,
636
+ use_cache,
637
+ cache_position,
638
+ position_embeddings,
639
+ )
640
+ else:
641
+ layer_outputs = decoder_layer(
642
+ hidden_states,
643
+ attention_mask=causal_mask,
644
+ position_ids=position_ids,
645
+ past_key_value=past_key_values,
646
+ output_attentions=output_attentions,
647
+ use_cache=use_cache,
648
+ cache_position=cache_position,
649
+ position_embeddings=position_embeddings,
650
+ **flash_attn_kwargs,
651
+ )
652
+
653
+ hidden_states = layer_outputs[0]
654
+
655
+ if output_attentions:
656
+ all_self_attns += (layer_outputs[1],)
657
+
658
+ hidden_states = self.norm(hidden_states)
659
+
660
+ # add hidden states from the last decoder layer
661
+ if output_hidden_states:
662
+ all_hidden_states += (hidden_states,)
663
+
664
+ output = BaseModelOutputWithPast(
665
+ last_hidden_state=hidden_states,
666
+ past_key_values=past_key_values if use_cache else None,
667
+ hidden_states=all_hidden_states,
668
+ attentions=all_self_attns,
669
+ )
670
+ return output if return_dict else output.to_tuple()
671
+
672
+ def _update_causal_mask(
673
+ self,
674
+ attention_mask: torch.Tensor,
675
+ input_tensor: torch.Tensor,
676
+ cache_position: torch.Tensor,
677
+ past_key_values: Cache,
678
+ output_attentions: bool,
679
+ ):
680
+ if self.config._attn_implementation == "flash_attention_2":
681
+ if attention_mask is not None and past_key_values is not None:
682
+ is_padding_right = attention_mask[:, -1].sum().item() != input_tensor.size()[0]
683
+ if is_padding_right:
684
+ raise ValueError(
685
+ "You are attempting to perform batched generation with padding_side='right'"
686
+ " this may lead to unexpected behaviour for Flash Attention version of Phi3. Make sure to "
687
+ " call `tokenizer.padding_side = 'left'` before tokenizing the input. "
688
+ )
689
+ if attention_mask is not None and 0.0 in attention_mask:
690
+ return attention_mask
691
+ return None
692
+
693
+ # For SDPA, when possible, we will rely on its `is_causal` argument instead of its `attn_mask` argument, in
694
+ # order to dispatch on Flash Attention 2. This feature is not compatible with static cache, as SDPA will fail
695
+ # to infer the attention mask.
696
+ past_seen_tokens = past_key_values.get_seq_length() if past_key_values is not None else 0
697
+ using_static_cache = isinstance(past_key_values, StaticCache)
698
+ using_sliding_window_cache = isinstance(past_key_values, SlidingWindowCache)
699
+
700
+ # When output attentions is True, sdpa implementation's forward method calls the eager implementation's forward
701
+ if (
702
+ self.config._attn_implementation == "sdpa"
703
+ and not (using_static_cache or using_sliding_window_cache)
704
+ and not output_attentions
705
+ ):
706
+ if AttentionMaskConverter._ignore_causal_mask_sdpa(
707
+ attention_mask,
708
+ inputs_embeds=input_tensor,
709
+ past_key_values_length=past_seen_tokens,
710
+ sliding_window=self.config.sliding_window,
711
+ is_training=self.training,
712
+ ):
713
+ return None
714
+
715
+ dtype, device = input_tensor.dtype, input_tensor.device
716
+ min_dtype = torch.finfo(dtype).min
717
+ sequence_length = input_tensor.shape[1]
718
+ # SlidingWindowCache or StaticCache
719
+ if using_sliding_window_cache or using_static_cache:
720
+ target_length = past_key_values.get_max_cache_shape()
721
+ # DynamicCache or no cache
722
+ else:
723
+ target_length = (
724
+ attention_mask.shape[-1]
725
+ if isinstance(attention_mask, torch.Tensor)
726
+ else past_seen_tokens + sequence_length + 1
727
+ )
728
+
729
+ # In case the provided `attention` mask is 2D, we generate a causal mask here (4D).
730
+ causal_mask = self._prepare_4d_causal_attention_mask_with_cache_position(
731
+ attention_mask,
732
+ sequence_length=sequence_length,
733
+ target_length=target_length,
734
+ dtype=dtype,
735
+ device=device,
736
+ cache_position=cache_position,
737
+ batch_size=input_tensor.shape[0],
738
+ config=self.config,
739
+ past_key_values=past_key_values,
740
+ )
741
+
742
+ if (
743
+ self.config._attn_implementation == "sdpa"
744
+ and attention_mask is not None
745
+ and attention_mask.device.type in ["cuda", "xpu"]
746
+ and not output_attentions
747
+ ):
748
+ # Attend to all tokens in fully masked rows in the causal_mask, for example the relevant first rows when
749
+ # using left padding. This is required by F.scaled_dot_product_attention memory-efficient attention path.
750
+ # Details: https://github.com/pytorch/pytorch/issues/110213
751
+ causal_mask = AttentionMaskConverter._unmask_unattended(causal_mask, min_dtype)
752
+
753
+ return causal_mask
754
+
755
+ @staticmethod
756
+ def _prepare_4d_causal_attention_mask_with_cache_position(
757
+ attention_mask: torch.Tensor,
758
+ sequence_length: int,
759
+ target_length: int,
760
+ dtype: torch.dtype,
761
+ device: torch.device,
762
+ cache_position: torch.Tensor,
763
+ batch_size: int,
764
+ config: Phi3Config,
765
+ past_key_values: Cache,
766
+ ):
767
+ """
768
+ Creates a causal 4D mask of shape `(batch_size, 1, query_length, key_value_length)` from a 2D mask of shape
769
+ `(batch_size, key_value_length)`, or if the input `attention_mask` is already 4D, do nothing.
770
+
771
+ Args:
772
+ attention_mask (`torch.Tensor`):
773
+ A 2D attention mask of shape `(batch_size, key_value_length)` or a 4D attention mask of shape `(batch_size, 1, query_length, key_value_length)`.
774
+ sequence_length (`int`):
775
+ The sequence length being processed.
776
+ target_length (`int`):
777
+ The target length: when generating with static cache, the mask should be as long as the static cache, to account for the 0 padding, the part of the cache that is not filled yet.
778
+ dtype (`torch.dtype`):
779
+ The dtype to use for the 4D attention mask.
780
+ device (`torch.device`):
781
+ The device to plcae the 4D attention mask on.
782
+ cache_position (`torch.Tensor`):
783
+ Indices depicting the position of the input sequence tokens in the sequence.
784
+ batch_size (`torch.Tensor`):
785
+ Batch size.
786
+ config (`Phi3Config`):
787
+ The model's configuration class
788
+ past_key_values (`Cache`):
789
+ The cache class that is being used currently to generate
790
+ """
791
+ if attention_mask is not None and attention_mask.dim() == 4:
792
+ # In this case we assume that the mask comes already in inverted form and requires no inversion or slicing.
793
+ causal_mask = attention_mask
794
+ else:
795
+ min_dtype = torch.finfo(dtype).min
796
+ causal_mask = torch.full(
797
+ (sequence_length, target_length), fill_value=min_dtype, dtype=dtype, device=device
798
+ )
799
+ diagonal_attend_mask = torch.arange(target_length, device=device) > cache_position.reshape(-1, 1)
800
+ if config.sliding_window is not None:
801
+ # if we have sliding window, we should not attend to tokens beyond sliding window length, so we mask them out also
802
+ # the check is needed to verify is current checkpoint was trained with sliding window or not
803
+ if not isinstance(past_key_values, SlidingWindowCache) or sequence_length > target_length:
804
+ sliding_attend_mask = torch.arange(target_length, device=device) <= (
805
+ cache_position.reshape(-1, 1) - config.sliding_window
806
+ )
807
+ diagonal_attend_mask.bitwise_or_(sliding_attend_mask)
808
+ causal_mask *= diagonal_attend_mask
809
+ causal_mask = causal_mask[None, None, :, :].expand(batch_size, 1, -1, -1)
810
+ if attention_mask is not None:
811
+ causal_mask = causal_mask.clone() # copy to contiguous memory for in-place edit
812
+ if attention_mask.shape[-1] > target_length:
813
+ attention_mask = attention_mask[:, :target_length]
814
+ mask_length = attention_mask.shape[-1]
815
+ padding_mask = causal_mask[:, :, :, :mask_length] + attention_mask[:, None, None, :].to(
816
+ causal_mask.device
817
+ )
818
+ padding_mask = padding_mask == 0
819
+ causal_mask[:, :, :, :mask_length] = causal_mask[:, :, :, :mask_length].masked_fill(
820
+ padding_mask, min_dtype
821
+ )
822
+ return causal_mask
823
+
824
+
825
+ class KwargsForCausalLM(FlashAttentionKwargs, LossKwargs): ...
826
+
827
+
828
+ class Phi3ForCausalLM(Phi3PreTrainedModel, GenerationMixin):
829
+ _tied_weights_keys = ["lm_head.weight"]
830
+ _tp_plan = {"lm_head": "colwise_rep"}
831
+ _pp_plan = {"lm_head": (["hidden_states"], ["logits"])}
832
+
833
+ def __init__(self, config):
834
+ super().__init__(config)
835
+ self.model = Phi3Model(config)
836
+ self.vocab_size = config.vocab_size
837
+ self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False)
838
+
839
+ # Initialize weights and apply final processing
840
+ self.post_init()
841
+
842
+ def get_input_embeddings(self):
843
+ return self.model.embed_tokens
844
+
845
+ def set_input_embeddings(self, value):
846
+ self.model.embed_tokens = value
847
+
848
+ def get_output_embeddings(self):
849
+ return self.lm_head
850
+
851
+ def set_output_embeddings(self, new_embeddings):
852
+ self.lm_head = new_embeddings
853
+
854
+ def set_decoder(self, decoder):
855
+ self.model = decoder
856
+
857
+ def get_decoder(self):
858
+ return self.model
859
+
860
+ @deprecate_kwarg("num_logits_to_keep", version="4.50", new_name="logits_to_keep")
861
+ @add_start_docstrings_to_model_forward(PHI3_INPUTS_DOCSTRING)
862
+ @replace_return_docstrings(output_type=CausalLMOutputWithPast, config_class=_CONFIG_FOR_DOC)
863
+ def forward(
864
+ self,
865
+ input_ids: torch.LongTensor = None,
866
+ attention_mask: Optional[torch.Tensor] = None,
867
+ position_ids: Optional[torch.LongTensor] = None,
868
+ past_key_values: Optional[Union[Cache, List[torch.FloatTensor]]] = None,
869
+ inputs_embeds: Optional[torch.FloatTensor] = None,
870
+ labels: Optional[torch.LongTensor] = None,
871
+ use_cache: Optional[bool] = None,
872
+ output_attentions: Optional[bool] = None,
873
+ output_hidden_states: Optional[bool] = None,
874
+ return_dict: Optional[bool] = None,
875
+ cache_position: Optional[torch.LongTensor] = None,
876
+ logits_to_keep: Union[int, torch.Tensor] = 0,
877
+ **kwargs: Unpack[KwargsForCausalLM],
878
+ ) -> Union[Tuple, CausalLMOutputWithPast]:
879
+ r"""
880
+ Args:
881
+ labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
882
+ Labels for computing the masked language modeling loss. Indices should either be in `[0, ...,
883
+ config.vocab_size]` or -100 (see `input_ids` docstring). Tokens with indices set to `-100` are ignored
884
+ (masked), the loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]`.
885
+
886
+ logits_to_keep (`int` or `torch.Tensor`, *optional*):
887
+ If an `int`, compute logits for the last `logits_to_keep` tokens. If `0`, calculate logits for all
888
+ `input_ids` (special case). Only last token logits are needed for generation, and calculating them only for that
889
+ token can save memory, which becomes pretty significant for long sequences or large vocabulary size.
890
+ If a `torch.Tensor`, must be 1D corresponding to the indices to keep in the sequence length dimension.
891
+ This is useful when using packed tensor format (single dimension for batch and sequence length).
892
+
893
+ Returns:
894
+
895
+ Example:
896
+
897
+ ```python
898
+ >>> from transformers import AutoTokenizer, Phi3ForCausalLM
899
+
900
+ >>> model = Phi3ForCausalLM.from_pretrained("meta-phi3/Phi3-2-7b-hf")
901
+ >>> tokenizer = AutoTokenizer.from_pretrained("meta-phi3/Phi3-2-7b-hf")
902
+
903
+ >>> prompt = "Hey, are you conscious? Can you talk to me?"
904
+ >>> inputs = tokenizer(prompt, return_tensors="pt")
905
+
906
+ >>> # Generate
907
+ >>> generate_ids = model.generate(inputs.input_ids, max_length=30)
908
+ >>> tokenizer.batch_decode(generate_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False)[0]
909
+ "Hey, are you conscious? Can you talk to me?\nI'm not conscious, but I can talk to you."
910
+ ```"""
911
+ output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
912
+ output_hidden_states = (
913
+ output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
914
+ )
915
+ return_dict = return_dict if return_dict is not None else self.config.use_return_dict
916
+
917
+ # decoder outputs consists of (dec_features, layer_state, dec_hidden, dec_attn)
918
+ outputs = self.model(
919
+ input_ids=input_ids,
920
+ attention_mask=attention_mask,
921
+ position_ids=position_ids,
922
+ past_key_values=past_key_values,
923
+ inputs_embeds=inputs_embeds,
924
+ use_cache=use_cache,
925
+ output_attentions=output_attentions,
926
+ output_hidden_states=output_hidden_states,
927
+ return_dict=return_dict,
928
+ cache_position=cache_position,
929
+ **kwargs,
930
+ )
931
+
932
+ hidden_states = outputs[0]
933
+ # Only compute necessary logits, and do not upcast them to float if we are not computing the loss
934
+ slice_indices = slice(-logits_to_keep, None) if isinstance(logits_to_keep, int) else logits_to_keep
935
+ logits = self.lm_head(hidden_states[:, slice_indices, :])
936
+
937
+ loss = None
938
+ if labels is not None:
939
+ loss = self.loss_function(logits=logits, labels=labels, vocab_size=self.config.vocab_size, **kwargs)
940
+
941
+ if not return_dict:
942
+ output = (logits,) + outputs[1:]
943
+ return (loss,) + output if loss is not None else output
944
+
945
+ return CausalLMOutputWithPast(
946
+ loss=loss,
947
+ logits=logits,
948
+ past_key_values=outputs.past_key_values,
949
+ hidden_states=outputs.hidden_states,
950
+ attentions=outputs.attentions,
951
+ )
952
+
953
+ def prepare_inputs_for_generation(
954
+ self,
955
+ input_ids,
956
+ past_key_values=None,
957
+ attention_mask=None,
958
+ inputs_embeds=None,
959
+ cache_position=None,
960
+ position_ids=None,
961
+ use_cache=True,
962
+ logits_to_keep=None,
963
+ **kwargs,
964
+ ):
965
+ # Overwritten -- this model may need to switch between short and long rope, invalidating the cache in the
966
+ # process
967
+
968
+ # When the first time input length reached long and short factor switching point, enforce re-compute cache
969
+ # It will cause downside of slower at this single token position, however, better than current failure.
970
+ if (
971
+ past_key_values
972
+ and self.config.rope_scaling
973
+ and input_ids.shape[1] >= self.config.original_max_position_embeddings + 1
974
+ ):
975
+ past_length = cache_position[0]
976
+ if past_length <= self.config.original_max_position_embeddings:
977
+ past_key_values = None
978
+
979
+ model_inputs = super().prepare_inputs_for_generation(
980
+ input_ids=input_ids,
981
+ past_key_values=past_key_values,
982
+ attention_mask=attention_mask,
983
+ inputs_embeds=inputs_embeds,
984
+ cache_position=cache_position,
985
+ position_ids=position_ids,
986
+ use_cache=use_cache,
987
+ logits_to_keep=logits_to_keep,
988
+ **kwargs,
989
+ )
990
+ return model_inputs
991
+
992
+
993
+ @add_start_docstrings(
994
+ """
995
+ The Phi3 Model transformer with a sequence classification head on top (linear layer).
996
+
997
+ [`Phi3ForSequenceClassification`] uses the last token in order to do the classification, as other causal models
998
+ (e.g. GPT-2) do.
999
+
1000
+ Since it does classification on the last token, it requires to know the position of the last token. If a
1001
+ `pad_token_id` is defined in the configuration, it finds the last token that is not a padding token in each row. If
1002
+ no `pad_token_id` is defined, it simply takes the last value in each row of the batch. Since it cannot guess the
1003
+ padding tokens when `inputs_embeds` are passed instead of `input_ids`, it does the same (take the last value in
1004
+ each row of the batch).
1005
+ """,
1006
+ PHI3_START_DOCSTRING,
1007
+ )
1008
+ class Phi3ForSequenceClassification(Phi3PreTrainedModel):
1009
+ def __init__(self, config):
1010
+ super().__init__(config)
1011
+ self.num_labels = config.num_labels
1012
+ self.model = Phi3Model(config)
1013
+ self.score = nn.Linear(config.hidden_size, self.num_labels, bias=False)
1014
+
1015
+ # Initialize weights and apply final processing
1016
+ self.post_init()
1017
+
1018
+ def get_input_embeddings(self):
1019
+ return self.model.embed_tokens
1020
+
1021
+ def set_input_embeddings(self, value):
1022
+ self.model.embed_tokens = value
1023
+
1024
+ @add_start_docstrings_to_model_forward(PHI3_INPUTS_DOCSTRING)
1025
+ def forward(
1026
+ self,
1027
+ input_ids: Optional[torch.LongTensor] = None,
1028
+ attention_mask: Optional[torch.Tensor] = None,
1029
+ position_ids: Optional[torch.LongTensor] = None,
1030
+ past_key_values: Optional[Union[Cache, List[torch.FloatTensor]]] = None,
1031
+ inputs_embeds: Optional[torch.FloatTensor] = None,
1032
+ labels: Optional[torch.LongTensor] = None,
1033
+ use_cache: Optional[bool] = None,
1034
+ output_attentions: Optional[bool] = None,
1035
+ output_hidden_states: Optional[bool] = None,
1036
+ return_dict: Optional[bool] = None,
1037
+ ) -> Union[Tuple, SequenceClassifierOutputWithPast]:
1038
+ r"""
1039
+ labels (`torch.LongTensor` of shape `(batch_size,)`, *optional*):
1040
+ Labels for computing the sequence classification/regression loss. Indices should be in `[0, ...,
1041
+ config.num_labels - 1]`. If `config.num_labels == 1` a regression loss is computed (Mean-Square loss), If
1042
+ `config.num_labels > 1` a classification loss is computed (Cross-Entropy).
1043
+ """
1044
+ return_dict = return_dict if return_dict is not None else self.config.use_return_dict
1045
+
1046
+ transformer_outputs = self.model(
1047
+ input_ids,
1048
+ attention_mask=attention_mask,
1049
+ position_ids=position_ids,
1050
+ past_key_values=past_key_values,
1051
+ inputs_embeds=inputs_embeds,
1052
+ use_cache=use_cache,
1053
+ output_attentions=output_attentions,
1054
+ output_hidden_states=output_hidden_states,
1055
+ return_dict=return_dict,
1056
+ )
1057
+ hidden_states = transformer_outputs[0]
1058
+ logits = self.score(hidden_states)
1059
+
1060
+ if input_ids is not None:
1061
+ batch_size = input_ids.shape[0]
1062
+ else:
1063
+ batch_size = inputs_embeds.shape[0]
1064
+
1065
+ if self.config.pad_token_id is None and batch_size != 1:
1066
+ raise ValueError("Cannot handle batch sizes > 1 if no padding token is defined.")
1067
+ if self.config.pad_token_id is None:
1068
+ last_non_pad_token = -1
1069
+ elif input_ids is not None:
1070
+ # To handle both left- and right- padding, we take the rightmost token that is not equal to pad_token_id
1071
+ non_pad_mask = (input_ids != self.config.pad_token_id).to(logits.device, torch.int32)
1072
+ token_indices = torch.arange(input_ids.shape[-1], device=logits.device)
1073
+ last_non_pad_token = (token_indices * non_pad_mask).argmax(-1)
1074
+ else:
1075
+ last_non_pad_token = -1
1076
+ logger.warning_once(
1077
+ f"{self.__class__.__name__} will not detect padding tokens in `inputs_embeds`. Results may be "
1078
+ "unexpected if using padding tokens in conjunction with `inputs_embeds.`"
1079
+ )
1080
+
1081
+ pooled_logits = logits[torch.arange(batch_size, device=logits.device), last_non_pad_token]
1082
+
1083
+ loss = None
1084
+ if labels is not None:
1085
+ loss = self.loss_function(logits=logits, labels=labels, pooled_logits=pooled_logits, config=self.config)
1086
+
1087
+ if not return_dict:
1088
+ output = (pooled_logits,) + transformer_outputs[1:]
1089
+ return ((loss,) + output) if loss is not None else output
1090
+
1091
+ return SequenceClassifierOutputWithPast(
1092
+ loss=loss,
1093
+ logits=pooled_logits,
1094
+ past_key_values=transformer_outputs.past_key_values,
1095
+ hidden_states=transformer_outputs.hidden_states,
1096
+ attentions=transformer_outputs.attentions,
1097
+ )
1098
+
1099
+
1100
+ @add_start_docstrings(
1101
+ """
1102
+ The Phi3 Model transformer with a token classification head on top (a linear layer on top of the hidden-states
1103
+ output) e.g. for Named-Entity-Recognition (NER) tasks.
1104
+ """,
1105
+ PHI3_START_DOCSTRING,
1106
+ )
1107
+ class Phi3ForTokenClassification(Phi3PreTrainedModel):
1108
+ def __init__(self, config):
1109
+ super().__init__(config)
1110
+ self.num_labels = config.num_labels
1111
+ self.model = Phi3Model(config)
1112
+ if getattr(config, "classifier_dropout", None) is not None:
1113
+ classifier_dropout = config.classifier_dropout
1114
+ elif getattr(config, "hidden_dropout", None) is not None:
1115
+ classifier_dropout = config.hidden_dropout
1116
+ else:
1117
+ classifier_dropout = 0.1
1118
+ self.dropout = nn.Dropout(classifier_dropout)
1119
+ self.score = nn.Linear(config.hidden_size, config.num_labels)
1120
+
1121
+ # Initialize weights and apply final processing
1122
+ self.post_init()
1123
+
1124
+ def get_input_embeddings(self):
1125
+ return self.model.embed_tokens
1126
+
1127
+ def set_input_embeddings(self, value):
1128
+ self.model.embed_tokens = value
1129
+
1130
+ @add_start_docstrings_to_model_forward(PHI3_INPUTS_DOCSTRING)
1131
+ @add_code_sample_docstrings(
1132
+ checkpoint=_CHECKPOINT_FOR_DOC,
1133
+ output_type=TokenClassifierOutput,
1134
+ config_class=_CONFIG_FOR_DOC,
1135
+ )
1136
+ def forward(
1137
+ self,
1138
+ input_ids: Optional[torch.LongTensor] = None,
1139
+ attention_mask: Optional[torch.Tensor] = None,
1140
+ position_ids: Optional[torch.LongTensor] = None,
1141
+ past_key_values: Optional[List[torch.FloatTensor]] = None,
1142
+ inputs_embeds: Optional[torch.FloatTensor] = None,
1143
+ labels: Optional[torch.LongTensor] = None,
1144
+ use_cache: Optional[bool] = None,
1145
+ output_attentions: Optional[bool] = None,
1146
+ output_hidden_states: Optional[bool] = None,
1147
+ return_dict: Optional[bool] = None,
1148
+ ) -> Union[Tuple, TokenClassifierOutput]:
1149
+ r"""
1150
+ labels (`torch.LongTensor` of shape `(batch_size,)`, *optional*):
1151
+ Labels for computing the sequence classification/regression loss. Indices should be in `[0, ...,
1152
+ config.num_labels - 1]`. If `config.num_labels == 1` a regression loss is computed (Mean-Square loss), If
1153
+ `config.num_labels > 1` a classification loss is computed (Cross-Entropy).
1154
+ """
1155
+ return_dict = return_dict if return_dict is not None else self.config.use_return_dict
1156
+
1157
+ outputs = self.model(
1158
+ input_ids,
1159
+ attention_mask=attention_mask,
1160
+ position_ids=position_ids,
1161
+ past_key_values=past_key_values,
1162
+ inputs_embeds=inputs_embeds,
1163
+ use_cache=use_cache,
1164
+ output_attentions=output_attentions,
1165
+ output_hidden_states=output_hidden_states,
1166
+ return_dict=return_dict,
1167
+ )
1168
+ sequence_output = outputs[0]
1169
+ sequence_output = self.dropout(sequence_output)
1170
+ logits = self.score(sequence_output)
1171
+
1172
+ loss = None
1173
+ if labels is not None:
1174
+ loss = self.loss_function(logits, labels, self.config)
1175
+
1176
+ if not return_dict:
1177
+ output = (logits,) + outputs[2:]
1178
+ return ((loss,) + output) if loss is not None else output
1179
+
1180
+ return TokenClassifierOutput(
1181
+ loss=loss,
1182
+ logits=logits,
1183
+ hidden_states=outputs.hidden_states,
1184
+ attentions=outputs.attentions,
1185
+ )