abdelrahmane01 commited on
Commit
28704fc
·
verified ·
1 Parent(s): 8018928

Delete Models/25/modeling_phi3.py

Browse files
Files changed (1) hide show
  1. Models/25/modeling_phi3.py +0 -1185
Models/25/modeling_phi3.py DELETED
@@ -1,1185 +0,0 @@
1
- # coding=utf-8
2
- # Copyright 2024 Microsoft and the HuggingFace Inc. team. All rights reserved.
3
- #
4
- # Licensed under the Apache License, Version 2.0 (the "License");
5
- # you may not use this file except in compliance with the License.
6
- # You may obtain a copy of the License at
7
- #
8
- # http://www.apache.org/licenses/LICENSE-2.0
9
- #
10
- # Unless required by applicable law or agreed to in writing, software
11
- # distributed under the License is distributed on an "AS IS" BASIS,
12
- # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
- # See the License for the specific language governing permissions and
14
- # limitations under the License.
15
-
16
- """PyTorch Phi-3 model."""
17
-
18
- from typing import Callable, List, Optional, Tuple, Union
19
-
20
- import torch
21
- from torch import nn
22
-
23
- from transformers.activations import ACT2FN
24
- from transformers.cache_utils import Cache, DynamicCache, SlidingWindowCache, StaticCache
25
- from transformers.generation import GenerationMixin
26
- from transformers.modeling_attn_mask_utils import AttentionMaskConverter
27
- from transformers.modeling_flash_attention_utils import FlashAttentionKwargs
28
- from transformers.modeling_outputs import (
29
- BaseModelOutputWithPast,
30
- CausalLMOutputWithPast,
31
- SequenceClassifierOutputWithPast,
32
- TokenClassifierOutput,
33
- )
34
- from transformers.modeling_rope_utils import ROPE_INIT_FUNCTIONS
35
- from transformers.modeling_utils import ALL_ATTENTION_FUNCTIONS, PreTrainedModel
36
- from transformers.processing_utils import Unpack
37
- from transformers.utils import (
38
- add_code_sample_docstrings,
39
- add_start_docstrings,
40
- add_start_docstrings_to_model_forward,
41
- logging,
42
- replace_return_docstrings,
43
- )
44
-
45
- # Robust import for LossKwargs
46
- try:
47
- from transformers.utils import LossKwargs
48
- except ImportError:
49
- from transformers.utils import TransformersKwargs as LossKwargs
50
- from transformers.utils.deprecation import deprecate_kwarg
51
- from .configuration_phi3 import Phi3Config
52
-
53
-
54
- logger = logging.get_logger(__name__)
55
-
56
- _CHECKPOINT_FOR_DOC = "microsoft/Phi-3-mini-4k-instruct"
57
- _CONFIG_FOR_DOC = "Phi3Config"
58
-
59
-
60
- class Phi3MLP(nn.Module):
61
- def __init__(self, config):
62
- super().__init__()
63
-
64
- self.config = config
65
- self.gate_up_proj = nn.Linear(config.hidden_size, 2 * config.intermediate_size, bias=False)
66
- self.down_proj = nn.Linear(config.intermediate_size, config.hidden_size, bias=False)
67
- self.activation_fn = ACT2FN[config.hidden_act]
68
-
69
- def forward(self, hidden_states: torch.FloatTensor) -> torch.FloatTensor:
70
- up_states = self.gate_up_proj(hidden_states)
71
-
72
- gate, up_states = up_states.chunk(2, dim=-1)
73
- up_states = up_states * self.activation_fn(gate)
74
-
75
- return self.down_proj(up_states)
76
-
77
-
78
- def rotate_half(x):
79
- """Rotates half the hidden dims of the input."""
80
- x1 = x[..., : x.shape[-1] // 2]
81
- x2 = x[..., x.shape[-1] // 2 :]
82
- return torch.cat((-x2, x1), dim=-1)
83
-
84
-
85
- def repeat_kv(hidden_states: torch.Tensor, n_rep: int) -> torch.Tensor:
86
- """
87
- This is the equivalent of torch.repeat_interleave(x, dim=1, repeats=n_rep). The hidden states go from (batch,
88
- num_key_value_heads, seqlen, head_dim) to (batch, num_attention_heads, seqlen, head_dim)
89
- """
90
- batch, num_key_value_heads, slen, head_dim = hidden_states.shape
91
- if n_rep == 1:
92
- return hidden_states
93
- hidden_states = hidden_states[:, :, None, :, :].expand(batch, num_key_value_heads, n_rep, slen, head_dim)
94
- return hidden_states.reshape(batch, num_key_value_heads * n_rep, slen, head_dim)
95
-
96
-
97
- def eager_attention_forward(
98
- module: nn.Module,
99
- query: torch.Tensor,
100
- key: torch.Tensor,
101
- value: torch.Tensor,
102
- attention_mask: Optional[torch.Tensor],
103
- scaling: float,
104
- dropout: float = 0.0,
105
- **kwargs,
106
- ):
107
- key_states = repeat_kv(key, module.num_key_value_groups)
108
- value_states = repeat_kv(value, module.num_key_value_groups)
109
-
110
- attn_weights = torch.matmul(query, key_states.transpose(2, 3)) * scaling
111
- if attention_mask is not None:
112
- causal_mask = attention_mask[:, :, :, : key_states.shape[-2]]
113
- attn_weights = attn_weights + causal_mask
114
-
115
- attn_weights = nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to(query.dtype)
116
- attn_weights = nn.functional.dropout(attn_weights, p=dropout, training=module.training)
117
- attn_output = torch.matmul(attn_weights, value_states)
118
- attn_output = attn_output.transpose(1, 2).contiguous()
119
-
120
- return attn_output, attn_weights
121
-
122
-
123
- def apply_rotary_pos_emb(q, k, cos, sin, position_ids=None, unsqueeze_dim=1):
124
- """Applies Rotary Position Embedding to the query and key tensors.
125
-
126
- Args:
127
- q (`torch.Tensor`): The query tensor.
128
- k (`torch.Tensor`): The key tensor.
129
- cos (`torch.Tensor`): The cosine part of the rotary embedding.
130
- sin (`torch.Tensor`): The sine part of the rotary embedding.
131
- position_ids (`torch.Tensor`, *optional*):
132
- Deprecated and unused.
133
- unsqueeze_dim (`int`, *optional*, defaults to 1):
134
- The 'unsqueeze_dim' argument specifies the dimension along which to unsqueeze cos[position_ids] and
135
- sin[position_ids] so that they can be properly broadcasted to the dimensions of q and k. For example, note
136
- that cos[position_ids] and sin[position_ids] have the shape [batch_size, seq_len, head_dim]. Then, if q and
137
- k have the shape [batch_size, heads, seq_len, head_dim], then setting unsqueeze_dim=1 makes
138
- cos[position_ids] and sin[position_ids] broadcastable to the shapes of q and k. Similarly, if q and k have
139
- the shape [batch_size, seq_len, heads, head_dim], then set unsqueeze_dim=2.
140
- Returns:
141
- `tuple(torch.Tensor)` comprising of the query and key tensors rotated using the Rotary Position Embedding.
142
- """
143
- cos = cos.unsqueeze(unsqueeze_dim)
144
- sin = sin.unsqueeze(unsqueeze_dim)
145
-
146
- rotary_dim = cos.shape[-1]
147
- q_rot, q_pass = q[..., :rotary_dim], q[..., rotary_dim:]
148
- k_rot, k_pass = k[..., :rotary_dim], k[..., rotary_dim:]
149
-
150
- q_embed = torch.cat([(q_rot * cos) + (rotate_half(q_rot) * sin), q_pass], dim=-1)
151
- k_embed = torch.cat([(k_rot * cos) + (rotate_half(k_rot) * sin), k_pass], dim=-1)
152
- return q_embed, k_embed
153
-
154
-
155
- class Phi3Attention(nn.Module):
156
- """Multi-headed attention from 'Attention Is All You Need' paper"""
157
-
158
- def __init__(self, config: Phi3Config, layer_idx: Optional[int] = None):
159
- super().__init__()
160
- self.config = config
161
- self.layer_idx = layer_idx
162
- self.head_dim = getattr(config, "head_dim", config.hidden_size // config.num_attention_heads)
163
- self.num_key_value_groups = config.num_attention_heads // config.num_key_value_heads
164
- self.num_key_value_heads = config.num_key_value_heads
165
- self.scaling = self.head_dim**-0.5
166
- self.attention_dropout = config.attention_dropout
167
- self.is_causal = True
168
-
169
- op_size = config.num_attention_heads * self.head_dim + 2 * (config.num_key_value_heads * self.head_dim)
170
- self.o_proj = nn.Linear(config.num_attention_heads * self.head_dim, config.hidden_size, bias=False)
171
- self.qkv_proj = nn.Linear(config.hidden_size, op_size, bias=False)
172
-
173
- def forward(
174
- self,
175
- hidden_states: torch.Tensor,
176
- position_embeddings: Tuple[torch.Tensor, torch.Tensor],
177
- attention_mask: Optional[torch.Tensor],
178
- past_key_value: Optional[Cache] = None,
179
- cache_position: Optional[torch.LongTensor] = None,
180
- **kwargs: Unpack[FlashAttentionKwargs],
181
- ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
182
- input_shape = hidden_states.shape[:-1]
183
- hidden_shape = (*input_shape, -1, self.head_dim)
184
-
185
- qkv = self.qkv_proj(hidden_states)
186
- query_pos = self.config.num_attention_heads * self.head_dim
187
- query_states = qkv[..., :query_pos]
188
- key_states = qkv[..., query_pos : query_pos + self.num_key_value_heads * self.head_dim]
189
- value_states = qkv[..., query_pos + self.num_key_value_heads * self.head_dim :]
190
-
191
- query_states = query_states.view(hidden_shape).transpose(1, 2)
192
- key_states = key_states.view(hidden_shape).transpose(1, 2)
193
- value_states = value_states.view(hidden_shape).transpose(1, 2)
194
-
195
- cos, sin = position_embeddings
196
- query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin)
197
-
198
- if past_key_value is not None:
199
- # sin and cos are specific to RoPE models; cache_position needed for the static cache
200
- cache_kwargs = {"sin": sin, "cos": cos, "cache_position": cache_position}
201
- key_states, value_states = past_key_value.update(key_states, value_states, self.layer_idx, cache_kwargs)
202
-
203
- attention_interface: Callable = eager_attention_forward
204
- if self.config._attn_implementation != "eager":
205
- if self.config._attn_implementation == "sdpa" and kwargs.get("output_attentions", False):
206
- logger.warning_once(
207
- "`torch.nn.functional.scaled_dot_product_attention` does not support `output_attentions=True`. Falling back to "
208
- 'eager attention. This warning can be removed using the argument `attn_implementation="eager"` when loading the model.'
209
- )
210
- else:
211
- attention_interface = ALL_ATTENTION_FUNCTIONS[self.config._attn_implementation]
212
-
213
- attn_output, attn_weights = attention_interface(
214
- self,
215
- query_states,
216
- key_states,
217
- value_states,
218
- attention_mask,
219
- dropout=0.0 if not self.training else self.attention_dropout,
220
- scaling=self.scaling,
221
- sliding_window=getattr(self.config, "sliding_window", None),
222
- **kwargs,
223
- )
224
-
225
- attn_output = attn_output.reshape(*input_shape, -1).contiguous()
226
- attn_output = self.o_proj(attn_output)
227
- return attn_output, attn_weights
228
-
229
-
230
- class Phi3RMSNorm(nn.Module):
231
- def __init__(self, hidden_size, eps=1e-6):
232
- """
233
- Phi3RMSNorm is equivalent to T5LayerNorm
234
- """
235
- super().__init__()
236
- self.weight = nn.Parameter(torch.ones(hidden_size))
237
- self.variance_epsilon = eps
238
-
239
- def forward(self, hidden_states):
240
- input_dtype = hidden_states.dtype
241
- hidden_states = hidden_states.to(torch.float32)
242
- variance = hidden_states.pow(2).mean(-1, keepdim=True)
243
- hidden_states = hidden_states * torch.rsqrt(variance + self.variance_epsilon)
244
- return self.weight * hidden_states.to(input_dtype)
245
-
246
- def extra_repr(self):
247
- return f"{tuple(self.weight.shape)}, eps={self.variance_epsilon}"
248
-
249
-
250
- class Phi3DecoderLayer(nn.Module):
251
- def __init__(self, config: Phi3Config, layer_idx: int):
252
- super().__init__()
253
- self.hidden_size = config.hidden_size
254
- self.self_attn = Phi3Attention(config=config, layer_idx=layer_idx)
255
- self.mlp = Phi3MLP(config)
256
- self.input_layernorm = Phi3RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
257
- self.post_attention_layernorm = Phi3RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
258
- self.config = config
259
- self.resid_attn_dropout = nn.Dropout(config.resid_pdrop)
260
- self.resid_mlp_dropout = nn.Dropout(config.resid_pdrop)
261
-
262
- def forward(
263
- self,
264
- hidden_states: torch.Tensor,
265
- attention_mask: Optional[torch.Tensor] = None,
266
- position_ids: Optional[torch.LongTensor] = None,
267
- past_key_value: Optional[Cache] = None,
268
- output_attentions: Optional[bool] = False,
269
- use_cache: Optional[bool] = False,
270
- cache_position: Optional[torch.LongTensor] = None,
271
- position_embeddings: Optional[Tuple[torch.Tensor, torch.Tensor]] = None, # necessary, but kept here for BC
272
- **kwargs: Unpack[FlashAttentionKwargs],
273
- ) -> Tuple[torch.FloatTensor, Optional[Tuple[torch.FloatTensor, torch.FloatTensor]]]:
274
- """
275
- Args:
276
- hidden_states (`torch.FloatTensor`):
277
- input to the layer of shape `(batch, seq_len, embed_dim)`
278
- attention_mask (`torch.FloatTensor`, *optional*): attention mask of size
279
- `(batch, 1, tgt_len, src_len)` where padding elements are indicated by very large negative values.
280
- position_ids (`torch.LongTensor` of shape `({0})`, *optional*):
281
- Indices of positions of each input sequence tokens in the position embeddings. Selected in the range
282
- `[0, config.n_positions - 1]`. [What are position IDs?](../glossary#position-ids)
283
- past_key_value (`Cache`, *optional*): cached past key and value projection states
284
- output_attentions (`bool`, *optional*):
285
- Whether or not to return the attentions tensors of all attention layers. See `attentions` under
286
- returned tensors for more detail.
287
- use_cache (`bool`, *optional*):
288
- If set to `True`, `past_key_values` key value states are returned and can be used to speed up decoding
289
- (see `past_key_values`).
290
- cache_position (`torch.LongTensor` of shape `(sequence_length)`, *optional*):
291
- Indices depicting the position of the input sequence tokens in the sequence
292
- kwargs (`dict`, *optional*):
293
- Arbitrary kwargs to be ignored, used for FSDP and other methods that injects code
294
- into the model
295
- """
296
- residual = hidden_states
297
-
298
- hidden_states = self.input_layernorm(hidden_states)
299
-
300
- # Self Attention
301
- hidden_states, self_attn_weights = self.self_attn(
302
- hidden_states=hidden_states,
303
- attention_mask=attention_mask,
304
- position_ids=position_ids,
305
- past_key_value=past_key_value,
306
- output_attentions=output_attentions,
307
- use_cache=use_cache,
308
- cache_position=cache_position,
309
- position_embeddings=position_embeddings,
310
- **kwargs,
311
- )
312
- hidden_states = residual + self.resid_attn_dropout(hidden_states) # main diff with Llama
313
-
314
- residual = hidden_states
315
- hidden_states = self.post_attention_layernorm(hidden_states)
316
- hidden_states = self.mlp(hidden_states)
317
- hidden_states = residual + self.resid_mlp_dropout(hidden_states) # main diff with Llama
318
-
319
- outputs = (hidden_states,)
320
- if output_attentions:
321
- outputs += (self_attn_weights,)
322
-
323
- return outputs
324
-
325
-
326
- class Phi3RotaryEmbedding(nn.Module):
327
- def __init__(self, config: Phi3Config, device=None):
328
- super().__init__()
329
- # BC: "rope_type" was originally "type"
330
- if hasattr(config, "rope_scaling") and config.rope_scaling is not None:
331
- self.rope_type = config.rope_scaling.get("rope_type", config.rope_scaling.get("type"))
332
- else:
333
- self.rope_type = "default"
334
- self.max_seq_len_cached = config.max_position_embeddings
335
- self.original_max_seq_len = config.max_position_embeddings
336
-
337
- self.config = config
338
- self.rope_init_fn = ROPE_INIT_FUNCTIONS[self.rope_type]
339
-
340
- inv_freq, self.attention_scaling = self.rope_init_fn(self.config, device)
341
- self.register_buffer("inv_freq", inv_freq, persistent=False)
342
- self.original_inv_freq = self.inv_freq
343
-
344
- def _dynamic_frequency_update(self, position_ids, device):
345
- """
346
- dynamic RoPE layers should recompute `inv_freq` in the following situations:
347
- 1 - growing beyond the cached sequence length (allow scaling)
348
- 2 - the current sequence length is in the original scale (avoid losing precision with small sequences)
349
- """
350
- seq_len = torch.max(position_ids) + 1
351
- if seq_len > self.max_seq_len_cached: # growth
352
- inv_freq, self.attention_scaling = self.rope_init_fn(self.config, device, seq_len=seq_len)
353
- self.register_buffer("inv_freq", inv_freq, persistent=False) # TODO joao: may break with compilation
354
- self.max_seq_len_cached = seq_len
355
-
356
- if seq_len < self.original_max_seq_len and self.max_seq_len_cached > self.original_max_seq_len: # reset
357
- # This .to() is needed if the model has been moved to a device after being initialized (because
358
- # the buffer is automatically moved, but not the original copy)
359
- self.original_inv_freq = self.original_inv_freq.to(device)
360
- self.register_buffer("inv_freq", self.original_inv_freq, persistent=False)
361
- self.max_seq_len_cached = self.original_max_seq_len
362
-
363
- @torch.no_grad()
364
- def forward(self, x, position_ids):
365
- if "dynamic" in self.rope_type:
366
- self._dynamic_frequency_update(position_ids, device=x.device)
367
- elif self.rope_type == "longrope":
368
- self._longrope_frequency_update(position_ids, device=x.device)
369
-
370
- # Core RoPE block
371
- inv_freq_expanded = self.inv_freq[None, :, None].float().expand(position_ids.shape[0], -1, 1)
372
- position_ids_expanded = position_ids[:, None, :].float()
373
- # Force float32 (see https://github.com/huggingface/transformers/pull/29285)
374
- device_type = x.device.type
375
- device_type = device_type if isinstance(device_type, str) and device_type != "mps" else "cpu"
376
- with torch.autocast(device_type=device_type, enabled=False):
377
- freqs = (inv_freq_expanded.float() @ position_ids_expanded.float()).transpose(1, 2)
378
- emb = torch.cat((freqs, freqs), dim=-1)
379
- cos = emb.cos()
380
- sin = emb.sin()
381
-
382
- # Advanced RoPE types (e.g. yarn) apply a post-processing scaling factor, equivalent to scaling attention
383
- cos = cos * self.attention_scaling
384
- sin = sin * self.attention_scaling
385
-
386
- return cos.to(dtype=x.dtype), sin.to(dtype=x.dtype)
387
-
388
- def _longrope_frequency_update(self, position_ids, device):
389
- """Longrope uses long factor if sequence is larger than original pretraining length, short otherwise."""
390
- seq_len = torch.max(position_ids) + 1
391
- if hasattr(self.config, "original_max_position_embeddings"):
392
- original_max_position_embeddings = self.config.original_max_position_embeddings
393
- else:
394
- original_max_position_embeddings = self.config.max_position_embeddings
395
- if seq_len > original_max_position_embeddings:
396
- if not hasattr(self, "long_inv_freq"):
397
- self.long_inv_freq, _ = self.rope_init_fn(
398
- self.config, device, seq_len=original_max_position_embeddings + 1
399
- )
400
- self.register_buffer("inv_freq", self.long_inv_freq, persistent=False)
401
- else:
402
- # This .to() is needed if the model has been moved to a device after being initialized (because
403
- # the buffer is automatically moved, but not the original copy)
404
- self.original_inv_freq = self.original_inv_freq.to(device)
405
- self.register_buffer("inv_freq", self.original_inv_freq, persistent=False)
406
-
407
-
408
- PHI3_START_DOCSTRING = r"""
409
- This model inherits from [`PreTrainedModel`]. Check the superclass documentation for the generic methods the
410
- library implements for all its model (such as downloading or saving, resizing the input embeddings, pruning heads
411
- etc.)
412
-
413
- This model is also a PyTorch [torch.nn.Module](https://pytorch.org/docs/stable/nn.html#torch.nn.Module) subclass.
414
- Use it as a regular PyTorch Module and refer to the PyTorch documentation for all matter related to general usage
415
- and behavior.
416
-
417
- Parameters:
418
- config ([`Phi3Config`]):
419
- Model configuration class with all the parameters of the model. Initializing with a config file does not
420
- load the weights associated with the model, only the configuration. Check out the
421
- [`~PreTrainedModel.from_pretrained`] method to load the model weights.
422
- """
423
-
424
-
425
- @add_start_docstrings(
426
- "The bare Phi3 Model outputting raw hidden-states without any specific head on top.",
427
- PHI3_START_DOCSTRING,
428
- )
429
- class Phi3PreTrainedModel(PreTrainedModel):
430
- config_class = Phi3Config
431
- base_model_prefix = "model"
432
- supports_gradient_checkpointing = True
433
- _no_split_modules = ["Phi3DecoderLayer"]
434
- _skip_keys_device_placement = ["past_key_values"]
435
- _supports_flash_attn_2 = True
436
- _supports_sdpa = True
437
- _supports_flex_attn = True
438
- _supports_cache_class = True
439
- _supports_quantized_cache = True
440
- _supports_static_cache = True
441
- _supports_attention_backend = True
442
- _version = "0.0.5"
443
-
444
- def _init_weights(self, module):
445
- std = self.config.initializer_range
446
- if isinstance(module, nn.Linear):
447
- module.weight.data.normal_(mean=0.0, std=std)
448
- if module.bias is not None:
449
- module.bias.data.zero_()
450
- elif isinstance(module, nn.Embedding):
451
- module.weight.data.normal_(mean=0.0, std=std)
452
- if module.padding_idx is not None:
453
- module.weight.data[module.padding_idx].zero_()
454
-
455
-
456
- PHI3_INPUTS_DOCSTRING = r"""
457
- Args:
458
- input_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`):
459
- Indices of input sequence tokens in the vocabulary. Padding will be ignored by default should you provide
460
- it.
461
-
462
- Indices can be obtained using [`AutoTokenizer`]. See [`PreTrainedTokenizer.encode`] and
463
- [`PreTrainedTokenizer.__call__`] for details.
464
-
465
- [What are input IDs?](../glossary#input-ids)
466
- attention_mask (`torch.Tensor` of shape `(batch_size, sequence_length)`, *optional*):
467
- Mask to avoid performing attention on padding token indices. Mask values selected in `[0, 1]`:
468
-
469
- - 1 for tokens that are **not masked**,
470
- - 0 for tokens that are **masked**.
471
-
472
- [What are attention masks?](../glossary#attention-mask)
473
-
474
- Indices can be obtained using [`AutoTokenizer`]. See [`PreTrainedTokenizer.encode`] and
475
- [`PreTrainedTokenizer.__call__`] for details.
476
-
477
- If `past_key_values` is used, optionally only the last `input_ids` have to be input (see
478
- `past_key_values`).
479
-
480
- If you want to change padding behavior, you should read [`modeling_opt._prepare_decoder_attention_mask`]
481
- and modify to your needs. See diagram 1 in [the paper](https://arxiv.org/abs/1910.13461) for more
482
- information on the default strategy.
483
-
484
- - 1 indicates the head is **not masked**,
485
- - 0 indicates the head is **masked**.
486
- position_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
487
- Indices of positions of each input sequence tokens in the position embeddings. Selected in the range `[0,
488
- config.n_positions - 1]`.
489
-
490
- [What are position IDs?](../glossary#position-ids)
491
- past_key_values (`Cache` or `tuple(tuple(torch.FloatTensor))`, *optional*):
492
- Pre-computed hidden-states (key and values in the self-attention blocks and in the cross-attention
493
- blocks) that can be used to speed up sequential decoding. This typically consists in the `past_key_values`
494
- returned by the model at a previous stage of decoding, when `use_cache=True` or `config.use_cache=True`.
495
-
496
- Two formats are allowed:
497
- - a [`~cache_utils.Cache`] instance, see our
498
- [kv cache guide](https://huggingface.co/docs/transformers/en/kv_cache);
499
- - Tuple of `tuple(torch.FloatTensor)` of length `config.n_layers`, with each tuple having 2 tensors of
500
- shape `(batch_size, num_heads, sequence_length, embed_size_per_head)`). This is also known as the legacy
501
- cache format.
502
-
503
- The model will output the same cache format that is fed as input. If no `past_key_values` are passed, the
504
- legacy cache format will be returned.
505
-
506
- If `past_key_values` are used, the user can optionally input only the last `input_ids` (those that don't
507
- have their past key value states given to this model) of shape `(batch_size, 1)` instead of all `input_ids`
508
- of shape `(batch_size, sequence_length)`.
509
- inputs_embeds (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`, *optional*):
510
- Optionally, instead of passing `input_ids` you can choose to directly pass an embedded representation. This
511
- is useful if you want more control over how to convert `input_ids` indices into associated vectors than the
512
- model's internal embedding lookup matrix.
513
- use_cache (`bool`, *optional*):
514
- If set to `True`, `past_key_values` key value states are returned and can be used to speed up decoding (see
515
- `past_key_values`).
516
- output_attentions (`bool`, *optional*):
517
- Whether or not to return the attentions tensors of all attention layers. See `attentions` under returned
518
- tensors for more detail.
519
- output_hidden_states (`bool`, *optional*):
520
- Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors for
521
- more detail.
522
- return_dict (`bool`, *optional*):
523
- Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple.
524
- cache_position (`torch.LongTensor` of shape `(sequence_length)`, *optional*):
525
- Indices depicting the position of the input sequence tokens in the sequence. Contrarily to `position_ids`,
526
- this tensor is not affected by padding. It is used to update the cache in the correct position and to infer
527
- the complete sequence length.
528
- """
529
-
530
-
531
- @add_start_docstrings(
532
- "The bare Phi3 Model outputting raw hidden-states without any specific head on top.",
533
- PHI3_START_DOCSTRING,
534
- )
535
- class Phi3Model(Phi3PreTrainedModel):
536
- """
537
- Transformer decoder consisting of *config.num_hidden_layers* layers. Each layer is a [`Phi3DecoderLayer`]
538
-
539
- Args:
540
- config: Phi3Config
541
- """
542
-
543
- def __init__(self, config: Phi3Config):
544
- super().__init__(config)
545
- self.padding_idx = config.pad_token_id
546
- self.vocab_size = config.vocab_size
547
-
548
- self.embed_tokens = nn.Embedding(config.vocab_size, config.hidden_size, self.padding_idx)
549
- self.layers = nn.ModuleList(
550
- [Phi3DecoderLayer(config, layer_idx) for layer_idx in range(config.num_hidden_layers)]
551
- )
552
- self.norm = Phi3RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
553
- self.rotary_emb = Phi3RotaryEmbedding(config=config)
554
- self.gradient_checkpointing = False
555
-
556
- # Initialize weights and apply final processing
557
- self.post_init()
558
-
559
- def get_input_embeddings(self):
560
- return self.embed_tokens
561
-
562
- def set_input_embeddings(self, value):
563
- self.embed_tokens = value
564
-
565
- @add_start_docstrings_to_model_forward(PHI3_INPUTS_DOCSTRING)
566
- def forward(
567
- self,
568
- input_ids: torch.LongTensor = None,
569
- attention_mask: Optional[torch.Tensor] = None,
570
- position_ids: Optional[torch.LongTensor] = None,
571
- past_key_values: Optional[Cache] = None,
572
- inputs_embeds: Optional[torch.FloatTensor] = None,
573
- use_cache: Optional[bool] = None,
574
- output_attentions: Optional[bool] = None,
575
- output_hidden_states: Optional[bool] = None,
576
- return_dict: Optional[bool] = None,
577
- cache_position: Optional[torch.LongTensor] = None,
578
- **flash_attn_kwargs: Unpack[FlashAttentionKwargs],
579
- ) -> Union[Tuple, BaseModelOutputWithPast]:
580
- output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
581
- output_hidden_states = (
582
- output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
583
- )
584
- use_cache = use_cache if use_cache is not None else self.config.use_cache
585
- return_dict = return_dict if return_dict is not None else self.config.use_return_dict
586
-
587
- if (input_ids is None) ^ (inputs_embeds is not None):
588
- raise ValueError("You must specify exactly one of input_ids or inputs_embeds")
589
-
590
- if self.gradient_checkpointing and self.training and use_cache:
591
- logger.warning_once(
592
- "`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`."
593
- )
594
- use_cache = False
595
-
596
- if inputs_embeds is None:
597
- inputs_embeds = self.embed_tokens(input_ids)
598
-
599
- if use_cache and past_key_values is None:
600
- past_key_values = DynamicCache()
601
-
602
- if cache_position is None:
603
- past_seen_tokens = past_key_values.get_seq_length() if past_key_values is not None else 0
604
- cache_position = torch.arange(
605
- past_seen_tokens, past_seen_tokens + inputs_embeds.shape[1], device=inputs_embeds.device
606
- )
607
-
608
- if position_ids is None:
609
- position_ids = cache_position.unsqueeze(0)
610
-
611
- causal_mask = self._update_causal_mask(
612
- attention_mask, inputs_embeds, cache_position, past_key_values, output_attentions
613
- )
614
-
615
- hidden_states = inputs_embeds
616
-
617
- # create position embeddings to be shared across the decoder layers
618
- position_embeddings = self.rotary_emb(hidden_states, position_ids)
619
-
620
- # decoder layers
621
- all_hidden_states = () if output_hidden_states else None
622
- all_self_attns = () if output_attentions else None
623
-
624
- for decoder_layer in self.layers[: self.config.num_hidden_layers]:
625
- if output_hidden_states:
626
- all_hidden_states += (hidden_states,)
627
-
628
- if self.gradient_checkpointing and self.training:
629
- layer_outputs = self._gradient_checkpointing_func(
630
- decoder_layer.__call__,
631
- hidden_states,
632
- causal_mask,
633
- position_ids,
634
- past_key_values,
635
- output_attentions,
636
- use_cache,
637
- cache_position,
638
- position_embeddings,
639
- )
640
- else:
641
- layer_outputs = decoder_layer(
642
- hidden_states,
643
- attention_mask=causal_mask,
644
- position_ids=position_ids,
645
- past_key_value=past_key_values,
646
- output_attentions=output_attentions,
647
- use_cache=use_cache,
648
- cache_position=cache_position,
649
- position_embeddings=position_embeddings,
650
- **flash_attn_kwargs,
651
- )
652
-
653
- hidden_states = layer_outputs[0]
654
-
655
- if output_attentions:
656
- all_self_attns += (layer_outputs[1],)
657
-
658
- hidden_states = self.norm(hidden_states)
659
-
660
- # add hidden states from the last decoder layer
661
- if output_hidden_states:
662
- all_hidden_states += (hidden_states,)
663
-
664
- output = BaseModelOutputWithPast(
665
- last_hidden_state=hidden_states,
666
- past_key_values=past_key_values if use_cache else None,
667
- hidden_states=all_hidden_states,
668
- attentions=all_self_attns,
669
- )
670
- return output if return_dict else output.to_tuple()
671
-
672
- def _update_causal_mask(
673
- self,
674
- attention_mask: torch.Tensor,
675
- input_tensor: torch.Tensor,
676
- cache_position: torch.Tensor,
677
- past_key_values: Cache,
678
- output_attentions: bool,
679
- ):
680
- if self.config._attn_implementation == "flash_attention_2":
681
- if attention_mask is not None and past_key_values is not None:
682
- is_padding_right = attention_mask[:, -1].sum().item() != input_tensor.size()[0]
683
- if is_padding_right:
684
- raise ValueError(
685
- "You are attempting to perform batched generation with padding_side='right'"
686
- " this may lead to unexpected behaviour for Flash Attention version of Phi3. Make sure to "
687
- " call `tokenizer.padding_side = 'left'` before tokenizing the input. "
688
- )
689
- if attention_mask is not None and 0.0 in attention_mask:
690
- return attention_mask
691
- return None
692
-
693
- # For SDPA, when possible, we will rely on its `is_causal` argument instead of its `attn_mask` argument, in
694
- # order to dispatch on Flash Attention 2. This feature is not compatible with static cache, as SDPA will fail
695
- # to infer the attention mask.
696
- past_seen_tokens = past_key_values.get_seq_length() if past_key_values is not None else 0
697
- using_static_cache = isinstance(past_key_values, StaticCache)
698
- using_sliding_window_cache = isinstance(past_key_values, SlidingWindowCache)
699
-
700
- # When output attentions is True, sdpa implementation's forward method calls the eager implementation's forward
701
- if (
702
- self.config._attn_implementation == "sdpa"
703
- and not (using_static_cache or using_sliding_window_cache)
704
- and not output_attentions
705
- ):
706
- if AttentionMaskConverter._ignore_causal_mask_sdpa(
707
- attention_mask,
708
- inputs_embeds=input_tensor,
709
- past_key_values_length=past_seen_tokens,
710
- sliding_window=self.config.sliding_window,
711
- is_training=self.training,
712
- ):
713
- return None
714
-
715
- dtype, device = input_tensor.dtype, input_tensor.device
716
- min_dtype = torch.finfo(dtype).min
717
- sequence_length = input_tensor.shape[1]
718
- # SlidingWindowCache or StaticCache
719
- if using_sliding_window_cache or using_static_cache:
720
- target_length = past_key_values.get_max_cache_shape()
721
- # DynamicCache or no cache
722
- else:
723
- target_length = (
724
- attention_mask.shape[-1]
725
- if isinstance(attention_mask, torch.Tensor)
726
- else past_seen_tokens + sequence_length + 1
727
- )
728
-
729
- # In case the provided `attention` mask is 2D, we generate a causal mask here (4D).
730
- causal_mask = self._prepare_4d_causal_attention_mask_with_cache_position(
731
- attention_mask,
732
- sequence_length=sequence_length,
733
- target_length=target_length,
734
- dtype=dtype,
735
- device=device,
736
- cache_position=cache_position,
737
- batch_size=input_tensor.shape[0],
738
- config=self.config,
739
- past_key_values=past_key_values,
740
- )
741
-
742
- if (
743
- self.config._attn_implementation == "sdpa"
744
- and attention_mask is not None
745
- and attention_mask.device.type in ["cuda", "xpu"]
746
- and not output_attentions
747
- ):
748
- # Attend to all tokens in fully masked rows in the causal_mask, for example the relevant first rows when
749
- # using left padding. This is required by F.scaled_dot_product_attention memory-efficient attention path.
750
- # Details: https://github.com/pytorch/pytorch/issues/110213
751
- causal_mask = AttentionMaskConverter._unmask_unattended(causal_mask, min_dtype)
752
-
753
- return causal_mask
754
-
755
- @staticmethod
756
- def _prepare_4d_causal_attention_mask_with_cache_position(
757
- attention_mask: torch.Tensor,
758
- sequence_length: int,
759
- target_length: int,
760
- dtype: torch.dtype,
761
- device: torch.device,
762
- cache_position: torch.Tensor,
763
- batch_size: int,
764
- config: Phi3Config,
765
- past_key_values: Cache,
766
- ):
767
- """
768
- Creates a causal 4D mask of shape `(batch_size, 1, query_length, key_value_length)` from a 2D mask of shape
769
- `(batch_size, key_value_length)`, or if the input `attention_mask` is already 4D, do nothing.
770
-
771
- Args:
772
- attention_mask (`torch.Tensor`):
773
- A 2D attention mask of shape `(batch_size, key_value_length)` or a 4D attention mask of shape `(batch_size, 1, query_length, key_value_length)`.
774
- sequence_length (`int`):
775
- The sequence length being processed.
776
- target_length (`int`):
777
- The target length: when generating with static cache, the mask should be as long as the static cache, to account for the 0 padding, the part of the cache that is not filled yet.
778
- dtype (`torch.dtype`):
779
- The dtype to use for the 4D attention mask.
780
- device (`torch.device`):
781
- The device to plcae the 4D attention mask on.
782
- cache_position (`torch.Tensor`):
783
- Indices depicting the position of the input sequence tokens in the sequence.
784
- batch_size (`torch.Tensor`):
785
- Batch size.
786
- config (`Phi3Config`):
787
- The model's configuration class
788
- past_key_values (`Cache`):
789
- The cache class that is being used currently to generate
790
- """
791
- if attention_mask is not None and attention_mask.dim() == 4:
792
- # In this case we assume that the mask comes already in inverted form and requires no inversion or slicing.
793
- causal_mask = attention_mask
794
- else:
795
- min_dtype = torch.finfo(dtype).min
796
- causal_mask = torch.full(
797
- (sequence_length, target_length), fill_value=min_dtype, dtype=dtype, device=device
798
- )
799
- diagonal_attend_mask = torch.arange(target_length, device=device) > cache_position.reshape(-1, 1)
800
- if config.sliding_window is not None:
801
- # if we have sliding window, we should not attend to tokens beyond sliding window length, so we mask them out also
802
- # the check is needed to verify is current checkpoint was trained with sliding window or not
803
- if not isinstance(past_key_values, SlidingWindowCache) or sequence_length > target_length:
804
- sliding_attend_mask = torch.arange(target_length, device=device) <= (
805
- cache_position.reshape(-1, 1) - config.sliding_window
806
- )
807
- diagonal_attend_mask.bitwise_or_(sliding_attend_mask)
808
- causal_mask *= diagonal_attend_mask
809
- causal_mask = causal_mask[None, None, :, :].expand(batch_size, 1, -1, -1)
810
- if attention_mask is not None:
811
- causal_mask = causal_mask.clone() # copy to contiguous memory for in-place edit
812
- if attention_mask.shape[-1] > target_length:
813
- attention_mask = attention_mask[:, :target_length]
814
- mask_length = attention_mask.shape[-1]
815
- padding_mask = causal_mask[:, :, :, :mask_length] + attention_mask[:, None, None, :].to(
816
- causal_mask.device
817
- )
818
- padding_mask = padding_mask == 0
819
- causal_mask[:, :, :, :mask_length] = causal_mask[:, :, :, :mask_length].masked_fill(
820
- padding_mask, min_dtype
821
- )
822
- return causal_mask
823
-
824
-
825
- class KwargsForCausalLM(FlashAttentionKwargs, LossKwargs): ...
826
-
827
-
828
- class Phi3ForCausalLM(Phi3PreTrainedModel, GenerationMixin):
829
- _tied_weights_keys = ["lm_head.weight"]
830
- _tp_plan = {"lm_head": "colwise_rep"}
831
- _pp_plan = {"lm_head": (["hidden_states"], ["logits"])}
832
-
833
- def __init__(self, config):
834
- super().__init__(config)
835
- self.model = Phi3Model(config)
836
- self.vocab_size = config.vocab_size
837
- self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False)
838
-
839
- # Initialize weights and apply final processing
840
- self.post_init()
841
-
842
- def get_input_embeddings(self):
843
- return self.model.embed_tokens
844
-
845
- def set_input_embeddings(self, value):
846
- self.model.embed_tokens = value
847
-
848
- def get_output_embeddings(self):
849
- return self.lm_head
850
-
851
- def set_output_embeddings(self, new_embeddings):
852
- self.lm_head = new_embeddings
853
-
854
- def set_decoder(self, decoder):
855
- self.model = decoder
856
-
857
- def get_decoder(self):
858
- return self.model
859
-
860
- @deprecate_kwarg("num_logits_to_keep", version="4.50", new_name="logits_to_keep")
861
- @add_start_docstrings_to_model_forward(PHI3_INPUTS_DOCSTRING)
862
- @replace_return_docstrings(output_type=CausalLMOutputWithPast, config_class=_CONFIG_FOR_DOC)
863
- def forward(
864
- self,
865
- input_ids: torch.LongTensor = None,
866
- attention_mask: Optional[torch.Tensor] = None,
867
- position_ids: Optional[torch.LongTensor] = None,
868
- past_key_values: Optional[Union[Cache, List[torch.FloatTensor]]] = None,
869
- inputs_embeds: Optional[torch.FloatTensor] = None,
870
- labels: Optional[torch.LongTensor] = None,
871
- use_cache: Optional[bool] = None,
872
- output_attentions: Optional[bool] = None,
873
- output_hidden_states: Optional[bool] = None,
874
- return_dict: Optional[bool] = None,
875
- cache_position: Optional[torch.LongTensor] = None,
876
- logits_to_keep: Union[int, torch.Tensor] = 0,
877
- **kwargs: Unpack[KwargsForCausalLM],
878
- ) -> Union[Tuple, CausalLMOutputWithPast]:
879
- r"""
880
- Args:
881
- labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
882
- Labels for computing the masked language modeling loss. Indices should either be in `[0, ...,
883
- config.vocab_size]` or -100 (see `input_ids` docstring). Tokens with indices set to `-100` are ignored
884
- (masked), the loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]`.
885
-
886
- logits_to_keep (`int` or `torch.Tensor`, *optional*):
887
- If an `int`, compute logits for the last `logits_to_keep` tokens. If `0`, calculate logits for all
888
- `input_ids` (special case). Only last token logits are needed for generation, and calculating them only for that
889
- token can save memory, which becomes pretty significant for long sequences or large vocabulary size.
890
- If a `torch.Tensor`, must be 1D corresponding to the indices to keep in the sequence length dimension.
891
- This is useful when using packed tensor format (single dimension for batch and sequence length).
892
-
893
- Returns:
894
-
895
- Example:
896
-
897
- ```python
898
- >>> from transformers import AutoTokenizer, Phi3ForCausalLM
899
-
900
- >>> model = Phi3ForCausalLM.from_pretrained("meta-phi3/Phi3-2-7b-hf")
901
- >>> tokenizer = AutoTokenizer.from_pretrained("meta-phi3/Phi3-2-7b-hf")
902
-
903
- >>> prompt = "Hey, are you conscious? Can you talk to me?"
904
- >>> inputs = tokenizer(prompt, return_tensors="pt")
905
-
906
- >>> # Generate
907
- >>> generate_ids = model.generate(inputs.input_ids, max_length=30)
908
- >>> tokenizer.batch_decode(generate_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False)[0]
909
- "Hey, are you conscious? Can you talk to me?\nI'm not conscious, but I can talk to you."
910
- ```"""
911
- output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
912
- output_hidden_states = (
913
- output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
914
- )
915
- return_dict = return_dict if return_dict is not None else self.config.use_return_dict
916
-
917
- # decoder outputs consists of (dec_features, layer_state, dec_hidden, dec_attn)
918
- outputs = self.model(
919
- input_ids=input_ids,
920
- attention_mask=attention_mask,
921
- position_ids=position_ids,
922
- past_key_values=past_key_values,
923
- inputs_embeds=inputs_embeds,
924
- use_cache=use_cache,
925
- output_attentions=output_attentions,
926
- output_hidden_states=output_hidden_states,
927
- return_dict=return_dict,
928
- cache_position=cache_position,
929
- **kwargs,
930
- )
931
-
932
- hidden_states = outputs[0]
933
- # Only compute necessary logits, and do not upcast them to float if we are not computing the loss
934
- slice_indices = slice(-logits_to_keep, None) if isinstance(logits_to_keep, int) else logits_to_keep
935
- logits = self.lm_head(hidden_states[:, slice_indices, :])
936
-
937
- loss = None
938
- if labels is not None:
939
- loss = self.loss_function(logits=logits, labels=labels, vocab_size=self.config.vocab_size, **kwargs)
940
-
941
- if not return_dict:
942
- output = (logits,) + outputs[1:]
943
- return (loss,) + output if loss is not None else output
944
-
945
- return CausalLMOutputWithPast(
946
- loss=loss,
947
- logits=logits,
948
- past_key_values=outputs.past_key_values,
949
- hidden_states=outputs.hidden_states,
950
- attentions=outputs.attentions,
951
- )
952
-
953
- def prepare_inputs_for_generation(
954
- self,
955
- input_ids,
956
- past_key_values=None,
957
- attention_mask=None,
958
- inputs_embeds=None,
959
- cache_position=None,
960
- position_ids=None,
961
- use_cache=True,
962
- logits_to_keep=None,
963
- **kwargs,
964
- ):
965
- # Overwritten -- this model may need to switch between short and long rope, invalidating the cache in the
966
- # process
967
-
968
- # When the first time input length reached long and short factor switching point, enforce re-compute cache
969
- # It will cause downside of slower at this single token position, however, better than current failure.
970
- if (
971
- past_key_values
972
- and self.config.rope_scaling
973
- and input_ids.shape[1] >= self.config.original_max_position_embeddings + 1
974
- ):
975
- past_length = cache_position[0]
976
- if past_length <= self.config.original_max_position_embeddings:
977
- past_key_values = None
978
-
979
- model_inputs = super().prepare_inputs_for_generation(
980
- input_ids=input_ids,
981
- past_key_values=past_key_values,
982
- attention_mask=attention_mask,
983
- inputs_embeds=inputs_embeds,
984
- cache_position=cache_position,
985
- position_ids=position_ids,
986
- use_cache=use_cache,
987
- logits_to_keep=logits_to_keep,
988
- **kwargs,
989
- )
990
- return model_inputs
991
-
992
-
993
- @add_start_docstrings(
994
- """
995
- The Phi3 Model transformer with a sequence classification head on top (linear layer).
996
-
997
- [`Phi3ForSequenceClassification`] uses the last token in order to do the classification, as other causal models
998
- (e.g. GPT-2) do.
999
-
1000
- Since it does classification on the last token, it requires to know the position of the last token. If a
1001
- `pad_token_id` is defined in the configuration, it finds the last token that is not a padding token in each row. If
1002
- no `pad_token_id` is defined, it simply takes the last value in each row of the batch. Since it cannot guess the
1003
- padding tokens when `inputs_embeds` are passed instead of `input_ids`, it does the same (take the last value in
1004
- each row of the batch).
1005
- """,
1006
- PHI3_START_DOCSTRING,
1007
- )
1008
- class Phi3ForSequenceClassification(Phi3PreTrainedModel):
1009
- def __init__(self, config):
1010
- super().__init__(config)
1011
- self.num_labels = config.num_labels
1012
- self.model = Phi3Model(config)
1013
- self.score = nn.Linear(config.hidden_size, self.num_labels, bias=False)
1014
-
1015
- # Initialize weights and apply final processing
1016
- self.post_init()
1017
-
1018
- def get_input_embeddings(self):
1019
- return self.model.embed_tokens
1020
-
1021
- def set_input_embeddings(self, value):
1022
- self.model.embed_tokens = value
1023
-
1024
- @add_start_docstrings_to_model_forward(PHI3_INPUTS_DOCSTRING)
1025
- def forward(
1026
- self,
1027
- input_ids: Optional[torch.LongTensor] = None,
1028
- attention_mask: Optional[torch.Tensor] = None,
1029
- position_ids: Optional[torch.LongTensor] = None,
1030
- past_key_values: Optional[Union[Cache, List[torch.FloatTensor]]] = None,
1031
- inputs_embeds: Optional[torch.FloatTensor] = None,
1032
- labels: Optional[torch.LongTensor] = None,
1033
- use_cache: Optional[bool] = None,
1034
- output_attentions: Optional[bool] = None,
1035
- output_hidden_states: Optional[bool] = None,
1036
- return_dict: Optional[bool] = None,
1037
- ) -> Union[Tuple, SequenceClassifierOutputWithPast]:
1038
- r"""
1039
- labels (`torch.LongTensor` of shape `(batch_size,)`, *optional*):
1040
- Labels for computing the sequence classification/regression loss. Indices should be in `[0, ...,
1041
- config.num_labels - 1]`. If `config.num_labels == 1` a regression loss is computed (Mean-Square loss), If
1042
- `config.num_labels > 1` a classification loss is computed (Cross-Entropy).
1043
- """
1044
- return_dict = return_dict if return_dict is not None else self.config.use_return_dict
1045
-
1046
- transformer_outputs = self.model(
1047
- input_ids,
1048
- attention_mask=attention_mask,
1049
- position_ids=position_ids,
1050
- past_key_values=past_key_values,
1051
- inputs_embeds=inputs_embeds,
1052
- use_cache=use_cache,
1053
- output_attentions=output_attentions,
1054
- output_hidden_states=output_hidden_states,
1055
- return_dict=return_dict,
1056
- )
1057
- hidden_states = transformer_outputs[0]
1058
- logits = self.score(hidden_states)
1059
-
1060
- if input_ids is not None:
1061
- batch_size = input_ids.shape[0]
1062
- else:
1063
- batch_size = inputs_embeds.shape[0]
1064
-
1065
- if self.config.pad_token_id is None and batch_size != 1:
1066
- raise ValueError("Cannot handle batch sizes > 1 if no padding token is defined.")
1067
- if self.config.pad_token_id is None:
1068
- last_non_pad_token = -1
1069
- elif input_ids is not None:
1070
- # To handle both left- and right- padding, we take the rightmost token that is not equal to pad_token_id
1071
- non_pad_mask = (input_ids != self.config.pad_token_id).to(logits.device, torch.int32)
1072
- token_indices = torch.arange(input_ids.shape[-1], device=logits.device)
1073
- last_non_pad_token = (token_indices * non_pad_mask).argmax(-1)
1074
- else:
1075
- last_non_pad_token = -1
1076
- logger.warning_once(
1077
- f"{self.__class__.__name__} will not detect padding tokens in `inputs_embeds`. Results may be "
1078
- "unexpected if using padding tokens in conjunction with `inputs_embeds.`"
1079
- )
1080
-
1081
- pooled_logits = logits[torch.arange(batch_size, device=logits.device), last_non_pad_token]
1082
-
1083
- loss = None
1084
- if labels is not None:
1085
- loss = self.loss_function(logits=logits, labels=labels, pooled_logits=pooled_logits, config=self.config)
1086
-
1087
- if not return_dict:
1088
- output = (pooled_logits,) + transformer_outputs[1:]
1089
- return ((loss,) + output) if loss is not None else output
1090
-
1091
- return SequenceClassifierOutputWithPast(
1092
- loss=loss,
1093
- logits=pooled_logits,
1094
- past_key_values=transformer_outputs.past_key_values,
1095
- hidden_states=transformer_outputs.hidden_states,
1096
- attentions=transformer_outputs.attentions,
1097
- )
1098
-
1099
-
1100
- @add_start_docstrings(
1101
- """
1102
- The Phi3 Model transformer with a token classification head on top (a linear layer on top of the hidden-states
1103
- output) e.g. for Named-Entity-Recognition (NER) tasks.
1104
- """,
1105
- PHI3_START_DOCSTRING,
1106
- )
1107
- class Phi3ForTokenClassification(Phi3PreTrainedModel):
1108
- def __init__(self, config):
1109
- super().__init__(config)
1110
- self.num_labels = config.num_labels
1111
- self.model = Phi3Model(config)
1112
- if getattr(config, "classifier_dropout", None) is not None:
1113
- classifier_dropout = config.classifier_dropout
1114
- elif getattr(config, "hidden_dropout", None) is not None:
1115
- classifier_dropout = config.hidden_dropout
1116
- else:
1117
- classifier_dropout = 0.1
1118
- self.dropout = nn.Dropout(classifier_dropout)
1119
- self.score = nn.Linear(config.hidden_size, config.num_labels)
1120
-
1121
- # Initialize weights and apply final processing
1122
- self.post_init()
1123
-
1124
- def get_input_embeddings(self):
1125
- return self.model.embed_tokens
1126
-
1127
- def set_input_embeddings(self, value):
1128
- self.model.embed_tokens = value
1129
-
1130
- @add_start_docstrings_to_model_forward(PHI3_INPUTS_DOCSTRING)
1131
- @add_code_sample_docstrings(
1132
- checkpoint=_CHECKPOINT_FOR_DOC,
1133
- output_type=TokenClassifierOutput,
1134
- config_class=_CONFIG_FOR_DOC,
1135
- )
1136
- def forward(
1137
- self,
1138
- input_ids: Optional[torch.LongTensor] = None,
1139
- attention_mask: Optional[torch.Tensor] = None,
1140
- position_ids: Optional[torch.LongTensor] = None,
1141
- past_key_values: Optional[List[torch.FloatTensor]] = None,
1142
- inputs_embeds: Optional[torch.FloatTensor] = None,
1143
- labels: Optional[torch.LongTensor] = None,
1144
- use_cache: Optional[bool] = None,
1145
- output_attentions: Optional[bool] = None,
1146
- output_hidden_states: Optional[bool] = None,
1147
- return_dict: Optional[bool] = None,
1148
- ) -> Union[Tuple, TokenClassifierOutput]:
1149
- r"""
1150
- labels (`torch.LongTensor` of shape `(batch_size,)`, *optional*):
1151
- Labels for computing the sequence classification/regression loss. Indices should be in `[0, ...,
1152
- config.num_labels - 1]`. If `config.num_labels == 1` a regression loss is computed (Mean-Square loss), If
1153
- `config.num_labels > 1` a classification loss is computed (Cross-Entropy).
1154
- """
1155
- return_dict = return_dict if return_dict is not None else self.config.use_return_dict
1156
-
1157
- outputs = self.model(
1158
- input_ids,
1159
- attention_mask=attention_mask,
1160
- position_ids=position_ids,
1161
- past_key_values=past_key_values,
1162
- inputs_embeds=inputs_embeds,
1163
- use_cache=use_cache,
1164
- output_attentions=output_attentions,
1165
- output_hidden_states=output_hidden_states,
1166
- return_dict=return_dict,
1167
- )
1168
- sequence_output = outputs[0]
1169
- sequence_output = self.dropout(sequence_output)
1170
- logits = self.score(sequence_output)
1171
-
1172
- loss = None
1173
- if labels is not None:
1174
- loss = self.loss_function(logits, labels, self.config)
1175
-
1176
- if not return_dict:
1177
- output = (logits,) + outputs[2:]
1178
- return ((loss,) + output) if loss is not None else output
1179
-
1180
- return TokenClassifierOutput(
1181
- loss=loss,
1182
- logits=logits,
1183
- hidden_states=outputs.hidden_states,
1184
- attentions=outputs.attentions,
1185
- )