Subh775 commited on
Commit
7d62412
·
verified ·
1 Parent(s): 7ae323f

Add modeling_phi.py for self-contained custom code

Browse files
Files changed (1) hide show
  1. modeling_phi.py +1463 -0
modeling_phi.py ADDED
@@ -0,0 +1,1463 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # coding=utf-8
2
+ # Copyright 2023 Microsoft and the HuggingFace Inc. team. All rights reserved.
3
+ #
4
+ # Licensed under the Apache License, Version 2.0 (the "License");
5
+ # you may not use this file except in compliance with the License.
6
+ # You may obtain a copy of the License at
7
+ #
8
+ # http://www.apache.org/licenses/LICENSE-2.0
9
+ #
10
+ # Unless required by applicable law or agreed to in writing, software
11
+ # distributed under the License is distributed on an "AS IS" BASIS,
12
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
+ # See the License for the specific language governing permissions and
14
+ # limitations under the License.
15
+
16
+ """PyTorch Phi model."""
17
+
18
+ import math
19
+ from typing import List, Optional, Tuple, Union
20
+
21
+ import torch
22
+ import torch.utils.checkpoint
23
+ from packaging import version
24
+ from torch import nn
25
+ from torch.nn import CrossEntropyLoss
26
+
27
+ from transformers.activations import ACT2FN
28
+ from transformers.cache_utils import Cache, DynamicCache, StaticCache
29
+ from transformers.modeling_attn_mask_utils import AttentionMaskConverter
30
+ from transformers.modeling_outputs import (
31
+ BaseModelOutputWithPast,
32
+ CausalLMOutputWithPast,
33
+ )
34
+ from transformers.modeling_utils import PreTrainedModel
35
+ from transformers.utils import (
36
+ add_start_docstrings,
37
+ add_start_docstrings_to_model_forward,
38
+ get_torch_version,
39
+ is_flash_attn_2_available,
40
+ is_flash_attn_greater_or_equal_2_10,
41
+ is_torchdynamo_compiling,
42
+ logging,
43
+ replace_return_docstrings,
44
+ )
45
+ from .configuration_moondream import PhiConfig
46
+
47
+
48
+ if is_flash_attn_2_available():
49
+ from transformers.modeling_flash_attention_utils import _flash_attention_forward
50
+
51
+
52
+ logger = logging.get_logger(__name__)
53
+
54
+ _CONFIG_FOR_DOC = "PhiConfig"
55
+
56
+
57
+ # Copied from transformers.models.llama.modeling_llama._prepare_4d_causal_attention_mask_with_cache_position
58
+ def _prepare_4d_causal_attention_mask_with_cache_position(
59
+ attention_mask: torch.Tensor,
60
+ sequence_length: int,
61
+ target_length: int,
62
+ dtype: torch.dtype,
63
+ device: torch.device,
64
+ min_dtype: float,
65
+ cache_position: torch.Tensor,
66
+ batch_size: int,
67
+ ):
68
+ """
69
+ Creates a causal 4D mask of shape `(batch_size, 1, query_length, key_value_length)` from a 2D mask of shape
70
+ `(batch_size, key_value_length)`, or if the input `attention_mask` is already 4D, do nothing.
71
+
72
+ Args:
73
+ attention_mask (`torch.Tensor`):
74
+ A 2D attention mask of shape `(batch_size, key_value_length)` or a 4D attention mask of shape `(batch_size, 1, query_length, key_value_length)`.
75
+ sequence_length (`int`):
76
+ The sequence length being processed.
77
+ target_length (`int`):
78
+ The target length: when generating with static cache, the mask should be as long as the static cache, to account for the 0 padding, the part of the cache that is not filled yet.
79
+ dtype (`torch.dtype`):
80
+ The dtype to use for the 4D attention mask.
81
+ device (`torch.device`):
82
+ The device to plcae the 4D attention mask on.
83
+ min_dtype (`float`):
84
+ The minimum value representable with the dtype `dtype`.
85
+ cache_position (`torch.Tensor`):
86
+ Indices depicting the position of the input sequence tokens in the sequence.
87
+ batch_size (`torch.Tensor`):
88
+ Batch size.
89
+ """
90
+ if attention_mask is not None and attention_mask.dim() == 4:
91
+ # In this case we assume that the mask comes already in inverted form and requires no inversion or slicing.
92
+ causal_mask = attention_mask
93
+ else:
94
+ causal_mask = torch.full(
95
+ (sequence_length, target_length),
96
+ fill_value=min_dtype,
97
+ dtype=dtype,
98
+ device=device,
99
+ )
100
+ if sequence_length != 1:
101
+ causal_mask = torch.triu(causal_mask, diagonal=1)
102
+ causal_mask *= torch.arange(
103
+ target_length, device=device
104
+ ) > cache_position.reshape(-1, 1)
105
+ causal_mask = causal_mask[None, None, :, :].expand(batch_size, 1, -1, -1)
106
+ if attention_mask is not None:
107
+ causal_mask = (
108
+ causal_mask.clone()
109
+ ) # copy to contiguous memory for in-place edit
110
+ mask_length = attention_mask.shape[-1]
111
+ padding_mask = (
112
+ causal_mask[:, :, :, :mask_length] + attention_mask[:, None, None, :]
113
+ )
114
+ padding_mask = padding_mask == 0
115
+ causal_mask[:, :, :, :mask_length] = causal_mask[
116
+ :, :, :, :mask_length
117
+ ].masked_fill(padding_mask, min_dtype)
118
+
119
+ return causal_mask
120
+
121
+
122
+ # Copied from transformers.models.mixtral.modeling_mixtral.MixtralRotaryEmbedding with Mixtral->Phi
123
+ class PhiRotaryEmbedding(nn.Module):
124
+ def __init__(self, dim, max_position_embeddings=2048, base=10000, device=None):
125
+ super().__init__()
126
+
127
+ self.dim = dim
128
+ self.max_position_embeddings = max_position_embeddings
129
+ self.base = base
130
+ inv_freq = 1.0 / (
131
+ self.base
132
+ ** (
133
+ torch.arange(0, self.dim, 2, dtype=torch.int64).float().to(device)
134
+ / self.dim
135
+ )
136
+ )
137
+ self.register_buffer("inv_freq", inv_freq, persistent=False)
138
+
139
+ # Build here to make `torch.jit.trace` work.
140
+ self._set_cos_sin_cache(
141
+ seq_len=max_position_embeddings,
142
+ device=self.inv_freq.device,
143
+ dtype=torch.get_default_dtype(),
144
+ )
145
+
146
+ def _set_cos_sin_cache(self, seq_len, device, dtype):
147
+ self.max_seq_len_cached = seq_len
148
+ t = torch.arange(
149
+ self.max_seq_len_cached, device=device, dtype=torch.int64
150
+ ).type_as(self.inv_freq)
151
+
152
+ freqs = torch.outer(t, self.inv_freq)
153
+ # Different from paper, but it uses a different permutation in order to obtain the same calculation
154
+ emb = torch.cat((freqs, freqs), dim=-1)
155
+ self.register_buffer("cos_cached", emb.cos().to(dtype), persistent=False)
156
+ self.register_buffer("sin_cached", emb.sin().to(dtype), persistent=False)
157
+
158
+ def forward(self, x, seq_len=None):
159
+ # x: [bs, num_attention_heads, seq_len, head_size]
160
+ if seq_len > self.max_seq_len_cached:
161
+ self._set_cos_sin_cache(seq_len=seq_len, device=x.device, dtype=x.dtype)
162
+
163
+ return (
164
+ self.cos_cached[:seq_len].to(dtype=x.dtype),
165
+ self.sin_cached[:seq_len].to(dtype=x.dtype),
166
+ )
167
+
168
+
169
+ # Copied from transformers.models.falcon.modeling_falcon.FalconLinearScalingRotaryEmbedding with Falcon->Phi
170
+ class PhiLinearScalingRotaryEmbedding(PhiRotaryEmbedding):
171
+ """PhiRotaryEmbedding extended with linear scaling. Credits to the Reddit user /u/kaiokendev"""
172
+
173
+ def __init__(
174
+ self,
175
+ dim,
176
+ max_position_embeddings=2048,
177
+ base=10000,
178
+ device=None,
179
+ scaling_factor=1.0,
180
+ ):
181
+ self.scaling_factor = scaling_factor
182
+ super().__init__(dim, max_position_embeddings, base, device)
183
+
184
+ def _set_cos_sin_cache(self, seq_len, device, dtype):
185
+ self.max_seq_len_cached = seq_len
186
+ t = torch.arange(
187
+ self.max_seq_len_cached, device=device, dtype=torch.int64
188
+ ).type_as(self.inv_freq)
189
+ t = t / self.scaling_factor
190
+
191
+ freqs = torch.outer(t, self.inv_freq)
192
+ # Different from paper, but it uses a different permutation in order to obtain the same calculation
193
+ emb = torch.cat((freqs, freqs), dim=-1)
194
+ self.register_buffer("cos_cached", emb.cos().to(dtype), persistent=False)
195
+ self.register_buffer("sin_cached", emb.sin().to(dtype), persistent=False)
196
+
197
+
198
+ # Copied from transformers.models.falcon.modeling_falcon.FalconDynamicNTKScalingRotaryEmbedding with Falcon->Phi
199
+ class PhiDynamicNTKScalingRotaryEmbedding(PhiRotaryEmbedding):
200
+ """PhiRotaryEmbedding extended with Dynamic NTK scaling. Credits to the Reddit users /u/bloc97 and /u/emozilla"""
201
+
202
+ def __init__(
203
+ self,
204
+ dim,
205
+ max_position_embeddings=2048,
206
+ base=10000,
207
+ device=None,
208
+ scaling_factor=1.0,
209
+ ):
210
+ self.scaling_factor = scaling_factor
211
+ super().__init__(dim, max_position_embeddings, base, device)
212
+
213
+ def _set_cos_sin_cache(self, seq_len, device, dtype):
214
+ self.max_seq_len_cached = seq_len
215
+
216
+ if seq_len > self.max_position_embeddings:
217
+ base = self.base * (
218
+ (self.scaling_factor * seq_len / self.max_position_embeddings)
219
+ - (self.scaling_factor - 1)
220
+ ) ** (self.dim / (self.dim - 2))
221
+ inv_freq = 1.0 / (
222
+ base
223
+ ** (
224
+ torch.arange(0, self.dim, 2, dtype=torch.int64).float().to(device)
225
+ / self.dim
226
+ )
227
+ )
228
+ self.register_buffer("inv_freq", inv_freq, persistent=False)
229
+
230
+ t = torch.arange(
231
+ self.max_seq_len_cached, device=device, dtype=torch.int64
232
+ ).type_as(self.inv_freq)
233
+
234
+ freqs = torch.outer(t, self.inv_freq)
235
+ # Different from paper, but it uses a different permutation in order to obtain the same calculation
236
+ emb = torch.cat((freqs, freqs), dim=-1)
237
+ self.register_buffer("cos_cached", emb.cos().to(dtype), persistent=False)
238
+ self.register_buffer("sin_cached", emb.sin().to(dtype), persistent=False)
239
+
240
+
241
+ # Copied from transformers.models.llama.modeling_llama.rotate_half
242
+ def rotate_half(x):
243
+ """Rotates half the hidden dims of the input."""
244
+ x1 = x[..., : x.shape[-1] // 2]
245
+ x2 = x[..., x.shape[-1] // 2 :]
246
+ return torch.cat((-x2, x1), dim=-1)
247
+
248
+
249
+ # Copied from transformers.models.mixtral.modeling_mixtral.apply_rotary_pos_emb
250
+ def apply_rotary_pos_emb(q, k, cos, sin, position_ids, unsqueeze_dim=1):
251
+ """Applies Rotary Position Embedding to the query and key tensors.
252
+
253
+ Args:
254
+ q (`torch.Tensor`): The query tensor.
255
+ k (`torch.Tensor`): The key tensor.
256
+ cos (`torch.Tensor`): The cosine part of the rotary embedding.
257
+ sin (`torch.Tensor`): The sine part of the rotary embedding.
258
+ position_ids (`torch.Tensor`):
259
+ The position indices of the tokens corresponding to the query and key tensors. For example, this can be
260
+ used to pass offsetted position ids when working with a KV-cache.
261
+ unsqueeze_dim (`int`, *optional*, defaults to 1):
262
+ The 'unsqueeze_dim' argument specifies the dimension along which to unsqueeze cos[position_ids] and
263
+ sin[position_ids] so that they can be properly broadcasted to the dimensions of q and k. For example, note
264
+ that cos[position_ids] and sin[position_ids] have the shape [batch_size, seq_len, head_dim]. Then, if q and
265
+ k have the shape [batch_size, heads, seq_len, head_dim], then setting unsqueeze_dim=1 makes
266
+ cos[position_ids] and sin[position_ids] broadcastable to the shapes of q and k. Similarly, if q and k have
267
+ the shape [batch_size, seq_len, heads, head_dim], then set unsqueeze_dim=2.
268
+ Returns:
269
+ `tuple(torch.Tensor)` comprising of the query and key tensors rotated using the Rotary Position Embedding.
270
+ """
271
+ cos = cos[position_ids].unsqueeze(unsqueeze_dim)
272
+ sin = sin[position_ids].unsqueeze(unsqueeze_dim)
273
+ q_embed = (q * cos) + (rotate_half(q) * sin)
274
+ k_embed = (k * cos) + (rotate_half(k) * sin)
275
+ return q_embed, k_embed
276
+
277
+
278
+ # Copied from transformers.models.clip.modeling_clip.CLIPMLP with CLIP->Phi
279
+ class PhiMLP(nn.Module):
280
+ def __init__(self, config):
281
+ super().__init__()
282
+ self.config = config
283
+ self.activation_fn = ACT2FN[config.hidden_act]
284
+ self.fc1 = nn.Linear(config.hidden_size, config.intermediate_size)
285
+ self.fc2 = nn.Linear(config.intermediate_size, config.hidden_size)
286
+
287
+ def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
288
+ hidden_states = self.fc1(hidden_states)
289
+ hidden_states = self.activation_fn(hidden_states)
290
+ hidden_states = self.fc2(hidden_states)
291
+ return hidden_states
292
+
293
+
294
+ # Copied from transformers.models.llama.modeling_llama.repeat_kv with llama->phi
295
+ def repeat_kv(hidden_states: torch.Tensor, n_rep: int) -> torch.Tensor:
296
+ """
297
+ This is the equivalent of torch.repeat_interleave(x, dim=1, repeats=n_rep). The hidden states go from (batch,
298
+ num_key_value_heads, seqlen, head_dim) to (batch, num_attention_heads, seqlen, head_dim)
299
+ """
300
+ batch, num_key_value_heads, slen, head_dim = hidden_states.shape
301
+ if n_rep == 1:
302
+ return hidden_states
303
+ hidden_states = hidden_states[:, :, None, :, :].expand(
304
+ batch, num_key_value_heads, n_rep, slen, head_dim
305
+ )
306
+ return hidden_states.reshape(batch, num_key_value_heads * n_rep, slen, head_dim)
307
+
308
+
309
+ class PhiAttention(nn.Module):
310
+ """Multi-headed attention from 'Attention Is All You Need' paper"""
311
+
312
+ def __init__(self, config: PhiConfig, layer_idx: Optional[int] = None):
313
+ super().__init__()
314
+ self.config = config
315
+ self.layer_idx = layer_idx
316
+ if layer_idx is None:
317
+ logger.warning_once(
318
+ f"Instantiating {self.__class__.__name__} without passing a `layer_idx` is not recommended and will "
319
+ "lead to errors during the forward call if caching is used. Please make sure to provide a `layer_idx` "
320
+ "when creating this class."
321
+ )
322
+
323
+ self.attention_dropout = config.attention_dropout
324
+ self.hidden_size = config.hidden_size
325
+ self.num_heads = config.num_attention_heads
326
+ self.head_dim = self.hidden_size // self.num_heads
327
+ self.num_key_value_heads = config.num_key_value_heads
328
+ self.num_key_value_groups = self.num_heads // self.num_key_value_heads
329
+ self.max_position_embeddings = config.max_position_embeddings
330
+ self.rope_theta = config.rope_theta
331
+ self.partial_rotary_factor = config.partial_rotary_factor
332
+ self.is_causal = True
333
+
334
+ if (self.head_dim * self.num_heads) != self.hidden_size:
335
+ raise ValueError(
336
+ f"hidden_size must be divisible by num_heads (got `hidden_size`: {self.hidden_size}"
337
+ f" and `num_heads`: {self.num_heads})."
338
+ )
339
+
340
+ self.Wqkv = nn.Linear(
341
+ self.hidden_size, 3 * self.num_heads * self.head_dim, bias=True
342
+ )
343
+ self.out_proj = nn.Linear(
344
+ self.num_heads * self.head_dim, self.hidden_size, bias=True
345
+ )
346
+
347
+ self._init_rope()
348
+
349
+ def _init_rope(self):
350
+ if self.config.rope_scaling is None:
351
+ self.rotary_emb = PhiRotaryEmbedding(
352
+ int(self.partial_rotary_factor * self.head_dim),
353
+ max_position_embeddings=self.max_position_embeddings,
354
+ base=self.rope_theta,
355
+ )
356
+ else:
357
+ scaling_type = self.config.rope_scaling["type"]
358
+ scaling_factor = self.config.rope_scaling["factor"]
359
+ if scaling_type == "linear":
360
+ self.rotary_emb = PhiLinearScalingRotaryEmbedding(
361
+ int(self.partial_rotary_factor * self.head_dim),
362
+ max_position_embeddings=self.max_position_embeddings,
363
+ scaling_factor=scaling_factor,
364
+ base=self.rope_theta,
365
+ )
366
+ elif scaling_type == "dynamic":
367
+ self.rotary_emb = PhiDynamicNTKScalingRotaryEmbedding(
368
+ int(self.partial_rotary_factor * self.head_dim),
369
+ max_position_embeddings=self.max_position_embeddings,
370
+ scaling_factor=scaling_factor,
371
+ base=self.rope_theta,
372
+ )
373
+ else:
374
+ raise ValueError(f"Unknown RoPE scaling type {scaling_type}")
375
+
376
+ def forward(
377
+ self,
378
+ hidden_states: torch.Tensor,
379
+ attention_mask: Optional[torch.Tensor] = None,
380
+ position_ids: Optional[torch.LongTensor] = None,
381
+ past_key_value: Optional[Cache] = None,
382
+ output_attentions: bool = False,
383
+ use_cache: bool = False,
384
+ cache_position: Optional[torch.LongTensor] = None,
385
+ ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
386
+ bsz, q_len, _ = hidden_states.size()
387
+
388
+ query_states, key_states, value_states = self.Wqkv(hidden_states).chunk(
389
+ 3, dim=-1
390
+ )
391
+
392
+ query_states = query_states.view(
393
+ bsz, q_len, self.num_heads, self.head_dim
394
+ ).transpose(1, 2)
395
+ key_states = key_states.view(
396
+ bsz, q_len, self.num_key_value_heads, self.head_dim
397
+ ).transpose(1, 2)
398
+ value_states = value_states.view(
399
+ bsz, q_len, self.num_key_value_heads, self.head_dim
400
+ ).transpose(1, 2)
401
+
402
+ kv_seq_len = key_states.shape[-2]
403
+ if past_key_value is not None:
404
+ if self.layer_idx is None:
405
+ raise ValueError(
406
+ f"The cache structure has changed since version v4.36. If you are using {self.__class__.__name__} "
407
+ "for auto-regressive decoding with k/v caching, please make sure to initialize the attention class "
408
+ "with a layer index."
409
+ )
410
+ kv_seq_len += past_key_value.get_usable_length(kv_seq_len, self.layer_idx)
411
+ cos, sin = self.rotary_emb(value_states, seq_len=kv_seq_len)
412
+
413
+ # Partial rotary embedding
414
+ query_rot, query_pass = (
415
+ query_states[..., : self.rotary_emb.dim],
416
+ query_states[..., self.rotary_emb.dim :],
417
+ )
418
+ key_rot, key_pass = (
419
+ key_states[..., : self.rotary_emb.dim],
420
+ key_states[..., self.rotary_emb.dim :],
421
+ )
422
+ # [batch_size, seq_length, num_heads, head_dim // config.partial_rotary_factor]
423
+ query_rot, key_rot = apply_rotary_pos_emb(
424
+ query_rot, key_rot, cos, sin, position_ids
425
+ )
426
+
427
+ # [batch_size, seq_length, num_heads, head_dim]
428
+ query_states = torch.cat((query_rot, query_pass), dim=-1)
429
+ key_states = torch.cat((key_rot, key_pass), dim=-1)
430
+
431
+ if past_key_value is not None:
432
+ cache_kwargs = {
433
+ "sin": sin,
434
+ "cos": cos,
435
+ "partial_rotation_size": self.rotary_emb.dim,
436
+ "cache_position": cache_position,
437
+ }
438
+ key_states, value_states = past_key_value.update(
439
+ key_states, value_states, self.layer_idx, cache_kwargs
440
+ )
441
+
442
+ key_states = repeat_kv(key_states, self.num_key_value_groups)
443
+ value_states = repeat_kv(value_states, self.num_key_value_groups)
444
+
445
+ # Queries and keys upcast to fp32 is required by Phi-2 to avoid overflow
446
+ attn_weights = torch.matmul(
447
+ query_states.to(torch.float32), key_states.to(torch.float32).transpose(2, 3)
448
+ ) / math.sqrt(self.head_dim)
449
+
450
+ if attn_weights.size() != (bsz, self.num_heads, q_len, kv_seq_len):
451
+ raise ValueError(
452
+ f"Attention weights should be of size {(bsz, self.num_heads, q_len, kv_seq_len)}, but is"
453
+ f" {attn_weights.size()}"
454
+ )
455
+
456
+ if attention_mask is not None:
457
+ causal_mask = attention_mask[:, :, :, : key_states.shape[-2]]
458
+ attn_weights += causal_mask
459
+
460
+ # upcast attention to fp32
461
+ attn_weights = nn.functional.softmax(
462
+ attn_weights, dim=-1, dtype=torch.float32
463
+ ).to(value_states.dtype)
464
+ attn_weights = nn.functional.dropout(
465
+ attn_weights, p=self.attention_dropout, training=self.training
466
+ )
467
+
468
+ attn_output = torch.matmul(attn_weights, value_states)
469
+
470
+ if attn_output.size() != (bsz, self.num_heads, q_len, self.head_dim):
471
+ raise ValueError(
472
+ f"`attn_output` should be of size {(bsz, self.num_heads, q_len, self.head_dim)}, but is"
473
+ f" {attn_output.size()}"
474
+ )
475
+
476
+ attn_output = attn_output.transpose(1, 2).contiguous()
477
+ attn_output = attn_output.reshape(bsz, q_len, self.hidden_size)
478
+
479
+ attn_output = self.out_proj(attn_output)
480
+
481
+ if not output_attentions:
482
+ attn_weights = None
483
+
484
+ return attn_output, attn_weights, past_key_value
485
+
486
+
487
+ class PhiFlashAttention2(PhiAttention):
488
+ """
489
+ Phi flash attention module. This module inherits from `PhiAttention` as the weights of the module stays
490
+ untouched. The only required change would be on the forward pass where it needs to correctly call the public API of
491
+ flash attention and deal with padding tokens in case the input contains any of them.
492
+ """
493
+
494
+ # Copied from transformers.models.llama.modeling_llama.LlamaFlashAttention2.__init__
495
+ def __init__(self, *args, **kwargs):
496
+ super().__init__(*args, **kwargs)
497
+
498
+ # TODO: Should be removed once Flash Attention for RoCm is bumped to 2.1.
499
+ # flash_attn<2.1 generates top-left aligned causal mask, while what is needed here is bottom-right alignement, that was made default for flash_attn>=2.1. This attribute is used to handle this difference. Reference: https://github.com/Dao-AILab/flash-attention/releases/tag/v2.1.0.
500
+ # Beware that with flash_attn<2.1, using q_seqlen != k_seqlen (except for the case q_seqlen == 1) produces a wrong mask (top-left).
501
+ self._flash_attn_uses_top_left_mask = not is_flash_attn_greater_or_equal_2_10()
502
+
503
+ def forward(
504
+ self,
505
+ hidden_states: torch.Tensor,
506
+ attention_mask: Optional[torch.LongTensor] = None,
507
+ position_ids: Optional[torch.LongTensor] = None,
508
+ past_key_value: Optional[Cache] = None,
509
+ output_attentions: bool = False,
510
+ use_cache: bool = False,
511
+ cache_position: Optional[torch.LongTensor] = None,
512
+ **kwargs,
513
+ ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
514
+ # PhiFlashAttention2 attention does not support output_attentions
515
+
516
+ output_attentions = False
517
+
518
+ bsz, q_len, _ = hidden_states.size()
519
+
520
+ query_states, key_states, value_states = self.Wqkv(hidden_states).chunk(
521
+ 3, dim=-1
522
+ )
523
+
524
+ # Flash attention requires the input to have the shape
525
+ # batch_size x seq_length x head_dim x hidden_dim
526
+ # therefore we just need to keep the original shape
527
+ query_states = query_states.view(
528
+ bsz, q_len, self.num_heads, self.head_dim
529
+ ).transpose(1, 2)
530
+ key_states = key_states.view(
531
+ bsz, q_len, self.num_key_value_heads, self.head_dim
532
+ ).transpose(1, 2)
533
+ value_states = value_states.view(
534
+ bsz, q_len, self.num_key_value_heads, self.head_dim
535
+ ).transpose(1, 2)
536
+
537
+ kv_seq_len = key_states.shape[-2]
538
+ if past_key_value is not None:
539
+ kv_seq_len += past_key_value.get_usable_length(kv_seq_len, self.layer_idx)
540
+ cos, sin = self.rotary_emb(value_states, seq_len=kv_seq_len)
541
+
542
+ # Partial rotary embedding
543
+ query_rot, query_pass = (
544
+ query_states[..., : self.rotary_emb.dim],
545
+ query_states[..., self.rotary_emb.dim :],
546
+ )
547
+ key_rot, key_pass = (
548
+ key_states[..., : self.rotary_emb.dim],
549
+ key_states[..., self.rotary_emb.dim :],
550
+ )
551
+ # [batch_size, seq_length, num_heads, head_dim // config.partial_rotary_factor]
552
+ query_rot, key_rot = apply_rotary_pos_emb(
553
+ query_rot, key_rot, cos, sin, position_ids
554
+ )
555
+
556
+ # [batch_size, seq_length, num_heads, head_dim]
557
+ query_states = torch.cat((query_rot, query_pass), dim=-1)
558
+ key_states = torch.cat((key_rot, key_pass), dim=-1)
559
+
560
+ if past_key_value is not None:
561
+ cache_kwargs = {
562
+ "sin": sin,
563
+ "cos": cos,
564
+ "partial_rotation_size": self.rotary_emb.dim,
565
+ "cache_position": cache_position,
566
+ }
567
+ key_states, value_states = past_key_value.update(
568
+ key_states, value_states, self.layer_idx, cache_kwargs
569
+ )
570
+
571
+ # TODO: These transpose are quite inefficient but Flash Attention requires the layout [batch_size, sequence_length, num_heads, head_dim]. We would need to refactor the KV cache
572
+ # to be able to avoid many of these transpose/reshape/view.
573
+ query_states = query_states.transpose(1, 2)
574
+ key_states = key_states.transpose(1, 2)
575
+ value_states = value_states.transpose(1, 2)
576
+
577
+ attn_dropout = self.attention_dropout if self.training else 0.0
578
+
579
+ # In PEFT, usually we cast the layer norms in float32 for training stability reasons
580
+ # therefore the input hidden states gets silently casted in float32. Hence, we need
581
+ # cast them back in the correct dtype just to be sure everything works as expected.
582
+ # This might slowdown training & inference so it is recommended to not cast the LayerNorms
583
+ # in fp32.
584
+
585
+ if query_states.dtype == torch.float32:
586
+ if torch.is_autocast_enabled():
587
+ target_dtype = torch.get_autocast_gpu_dtype()
588
+ # Handle the case where the model is quantized
589
+ elif hasattr(self.config, "_pre_quantization_dtype"):
590
+ target_dtype = self.config._pre_quantization_dtype
591
+ else:
592
+ target_dtype = self.q_proj.weight.dtype
593
+
594
+ logger.warning_once(
595
+ f"The input hidden states seems to be silently casted in float32, this might be related to"
596
+ f" the fact you have upcasted embedding or layer norm layers in float32. We will cast back the input in"
597
+ f" {target_dtype}."
598
+ )
599
+
600
+ query_states = query_states.to(target_dtype)
601
+ key_states = key_states.to(target_dtype)
602
+ value_states = value_states.to(target_dtype)
603
+
604
+ attn_output = _flash_attention_forward(
605
+ query_states,
606
+ key_states,
607
+ value_states,
608
+ attention_mask,
609
+ q_len,
610
+ position_ids=position_ids,
611
+ dropout=attn_dropout,
612
+ softmax_scale=None,
613
+ use_top_left_mask=self._flash_attn_uses_top_left_mask,
614
+ is_causal=self.is_causal,
615
+ )
616
+
617
+ attn_output = attn_output.reshape(bsz, q_len, self.hidden_size).contiguous()
618
+ attn_output = self.out_proj(attn_output)
619
+
620
+ if not output_attentions:
621
+ attn_weights = None
622
+
623
+ return attn_output, attn_weights, past_key_value
624
+
625
+
626
+ class PhiSdpaAttention(PhiAttention):
627
+ def __init__(self, *args, **kwargs):
628
+ super().__init__(*args, **kwargs)
629
+ self.require_contiguous_qkv = version.parse(
630
+ get_torch_version()
631
+ ) < version.parse("2.2.0")
632
+
633
+ """
634
+ SDPA attention module using torch.nn.functional.scaled_dot_product_attention. This module inherits from
635
+ `PhiAttention` as the weights of the module stays untouched. The only changes are on the forward pass to adapt to
636
+ SDPA API.
637
+ """
638
+
639
+ # Adapted from PhiAttention.forward
640
+ def forward(
641
+ self,
642
+ hidden_states: torch.Tensor,
643
+ attention_mask: Optional[torch.Tensor] = None,
644
+ position_ids: Optional[torch.LongTensor] = None,
645
+ past_key_value: Optional[Cache] = None,
646
+ output_attentions: bool = False,
647
+ use_cache: bool = False,
648
+ cache_position: Optional[torch.LongTensor] = None,
649
+ ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
650
+ if output_attentions:
651
+ # TODO: Improve this warning with e.g. `model.config.attn_implementation = "manual"` once this is implemented.
652
+ logger.warning_once(
653
+ "PhiModel is using PhiSdpaAttention, but `torch.nn.functional.scaled_dot_product_attention` does not "
654
+ "support `output_attentions=True`. Falling back to the manual attention implementation, but specifying "
655
+ "the manual implementation will be required from Transformers version v5.0.0 onwards. This warning can "
656
+ 'be removed using the argument `attn_implementation="eager"` when loading the model.'
657
+ )
658
+ return super().forward(
659
+ hidden_states=hidden_states,
660
+ attention_mask=attention_mask,
661
+ position_ids=position_ids,
662
+ past_key_value=past_key_value,
663
+ output_attentions=output_attentions,
664
+ use_cache=use_cache,
665
+ )
666
+
667
+ bsz, q_len, _ = hidden_states.size()
668
+
669
+ query_states, key_states, value_states = self.Wqkv(hidden_states).chunk(
670
+ 3, dim=-1
671
+ )
672
+
673
+ query_states = query_states.view(
674
+ bsz, q_len, self.num_heads, self.head_dim
675
+ ).transpose(1, 2)
676
+ key_states = key_states.view(
677
+ bsz, q_len, self.num_key_value_heads, self.head_dim
678
+ ).transpose(1, 2)
679
+ value_states = value_states.view(
680
+ bsz, q_len, self.num_key_value_heads, self.head_dim
681
+ ).transpose(1, 2)
682
+
683
+ kv_seq_len = key_states.shape[-2]
684
+ if past_key_value is not None:
685
+ if self.layer_idx is None:
686
+ raise ValueError(
687
+ f"The cache structure has changed since version v4.36. If you are using {self.__class__.__name__} "
688
+ "for auto-regressive decoding with k/v caching, please make sure to initialize the attention class "
689
+ "with a layer index."
690
+ )
691
+ kv_seq_len += past_key_value.get_usable_length(kv_seq_len, self.layer_idx)
692
+ cos, sin = self.rotary_emb(value_states, seq_len=kv_seq_len)
693
+
694
+ # Partial rotary embedding
695
+ query_rot, query_pass = (
696
+ query_states[..., : self.rotary_emb.dim],
697
+ query_states[..., self.rotary_emb.dim :],
698
+ )
699
+ key_rot, key_pass = (
700
+ key_states[..., : self.rotary_emb.dim],
701
+ key_states[..., self.rotary_emb.dim :],
702
+ )
703
+ # [batch_size, seq_length, num_heads, head_dim // config.partial_rotary_factor]
704
+ query_rot, key_rot = apply_rotary_pos_emb(
705
+ query_rot, key_rot, cos, sin, position_ids
706
+ )
707
+
708
+ # [batch_size, seq_length, num_heads, head_dim]
709
+ query_states = torch.cat((query_rot, query_pass), dim=-1)
710
+ key_states = torch.cat((key_rot, key_pass), dim=-1)
711
+
712
+ if past_key_value is not None:
713
+ cache_kwargs = {
714
+ "sin": sin,
715
+ "cos": cos,
716
+ "partial_rotation_size": self.rotary_emb.dim,
717
+ "cache_position": cache_position,
718
+ }
719
+ key_states, value_states = past_key_value.update(
720
+ key_states, value_states, self.layer_idx, cache_kwargs
721
+ )
722
+
723
+ key_states = repeat_kv(key_states, self.num_key_value_groups)
724
+ value_states = repeat_kv(value_states, self.num_key_value_groups)
725
+
726
+ causal_mask = attention_mask
727
+ if attention_mask is not None:
728
+ causal_mask = causal_mask[:, :, :, : key_states.shape[-2]]
729
+
730
+ # SDPA with memory-efficient backend is broken in torch==2.1.2 when using non-contiguous inputs and a custom
731
+ # attn_mask, so we need to call `.contiguous()` here. This was fixed in torch==2.2.0.
732
+ # Reference: https://github.com/pytorch/pytorch/issues/112577
733
+ if (
734
+ self.require_contiguous_qkv
735
+ and query_states.device.type == "cuda"
736
+ and attention_mask is not None
737
+ ):
738
+ query_states = query_states.contiguous()
739
+ key_states = key_states.contiguous()
740
+ value_states = value_states.contiguous()
741
+
742
+ # We dispatch to SDPA's Flash Attention or Efficient kernels via this `is_causal` if statement instead of an inline conditional assignment
743
+ # in SDPA to support both torch.compile's dynamic shapes and full graph options. An inline conditional prevents dynamic shapes from compiling.
744
+ is_causal = True if causal_mask is None and q_len > 1 else False
745
+
746
+ attn_output = torch.nn.functional.scaled_dot_product_attention(
747
+ query_states,
748
+ key_states,
749
+ value_states,
750
+ attn_mask=causal_mask,
751
+ dropout_p=self.attention_dropout if self.training else 0.0,
752
+ is_causal=is_causal,
753
+ )
754
+
755
+ attn_output = attn_output.transpose(1, 2).contiguous()
756
+ attn_output = attn_output.reshape(bsz, q_len, self.hidden_size)
757
+
758
+ attn_output = self.out_proj(attn_output)
759
+
760
+ return attn_output, None, past_key_value
761
+
762
+
763
+ PHI_ATTENTION_CLASSES = {
764
+ "eager": PhiAttention,
765
+ "flash_attention_2": PhiFlashAttention2,
766
+ "sdpa": PhiSdpaAttention,
767
+ }
768
+
769
+
770
+ class PhiDecoderLayer(nn.Module):
771
+ def __init__(self, config: PhiConfig, layer_idx: int):
772
+ super().__init__()
773
+ self.mixer = PHI_ATTENTION_CLASSES[config._attn_implementation](
774
+ config, layer_idx=layer_idx
775
+ )
776
+ self.mlp = PhiMLP(config)
777
+ self.ln = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
778
+ self.resid_dropout = nn.Dropout(config.resid_pdrop)
779
+
780
+ def forward(
781
+ self,
782
+ hidden_states: torch.Tensor,
783
+ attention_mask: Optional[torch.Tensor] = None,
784
+ position_ids: Optional[torch.LongTensor] = None,
785
+ output_attentions: Optional[bool] = False,
786
+ use_cache: Optional[bool] = False,
787
+ past_key_value: Optional[Tuple[torch.Tensor]] = None,
788
+ cache_position: Optional[torch.LongTensor] = None,
789
+ **kwargs,
790
+ ) -> Tuple[
791
+ torch.FloatTensor, Optional[Tuple[torch.FloatTensor, torch.FloatTensor]]
792
+ ]:
793
+ """
794
+ Args:
795
+ hidden_states (`torch.FloatTensor`):
796
+ input to the layer of shape `(batch, seq_len, embed_dim)`
797
+ attention_mask (`torch.FloatTensor`, *optional*): attention mask of size
798
+ `(batch, 1, tgt_len, src_len)` where padding elements are indicated by very large negative values.
799
+ position_ids (`torch.LongTensor` of shape `({0})`, *optional*):
800
+ Indices of positions of each input sequence tokens in the position embeddings. Selected in the range
801
+ `[0, config.n_positions - 1]`. [What are position IDs?](../glossary#position-ids)
802
+ output_attentions (`bool`, *optional*):
803
+ Whether or not to return the attentions tensors of all attention layers. See `attentions` under
804
+ returned tensors for more detail.
805
+ use_cache (`bool`, *optional*):
806
+ If set to `True`, `past_key_values` key value states are returned and can be used to speed up decoding
807
+ (see `past_key_values`).
808
+ past_key_value (`Tuple(torch.FloatTensor)`, *optional*): cached past key and value projection states
809
+ cache_position (`torch.LongTensor` of shape `(sequence_length)`, *optional*):
810
+ Indices depicting the position of the input sequence tokens in the sequence
811
+ kwargs (`dict`, *optional*):
812
+ Arbitrary kwargs to be ignored, used for FSDP and other methods that injects code
813
+ into the model
814
+ """
815
+
816
+ residual = hidden_states
817
+
818
+ hidden_states = self.ln(hidden_states)
819
+
820
+ # Self Attention
821
+ attn_outputs, self_attn_weights, present_key_value = self.mixer(
822
+ hidden_states=hidden_states,
823
+ attention_mask=attention_mask,
824
+ position_ids=position_ids,
825
+ past_key_value=past_key_value,
826
+ output_attentions=output_attentions,
827
+ use_cache=use_cache,
828
+ cache_position=cache_position,
829
+ )
830
+ attn_outputs = self.resid_dropout(attn_outputs)
831
+
832
+ feed_forward_hidden_states = self.resid_dropout(self.mlp(hidden_states))
833
+ hidden_states = attn_outputs + feed_forward_hidden_states + residual
834
+ outputs = (hidden_states,)
835
+
836
+ if output_attentions:
837
+ outputs += (self_attn_weights,)
838
+
839
+ if use_cache:
840
+ outputs += (present_key_value,)
841
+
842
+ return outputs
843
+
844
+
845
+ PHI_START_DOCSTRING = r"""
846
+ This model inherits from [`PreTrainedModel`]. Check the superclass documentation for the generic methods the
847
+ library implements for all its model (such as downloading or saving, resizing the input embeddings, pruning heads
848
+ etc.)
849
+
850
+ This model is also a PyTorch [torch.nn.Module](https://pytorch.org/docs/stable/nn.html#torch.nn.Module) subclass.
851
+ Use it as a regular PyTorch Module and refer to the PyTorch documentation for all matter related to general usage
852
+ and behavior.
853
+
854
+ Parameters:
855
+ config ([`PhiConfig`]):
856
+ Model configuration class with all the parameters of the model. Initializing with a config file does not
857
+ load the weights associated with the model, only the configuration. Check out the
858
+ [`~PreTrainedModel.from_pretrained`] method to load the model weights.
859
+ """
860
+
861
+
862
+ @add_start_docstrings(
863
+ "The bare Phi Model outputting raw hidden-states without any specific head on top.",
864
+ PHI_START_DOCSTRING,
865
+ )
866
+ class PhiPreTrainedModel(PreTrainedModel):
867
+ config_class = PhiConfig
868
+ base_model_prefix = "model"
869
+ supports_gradient_checkpointing = True
870
+ _no_split_modules = ["PhiDecoderLayer"]
871
+ _skip_keys_device_placement = "past_key_values"
872
+ _supports_flash_attn_2 = True
873
+ _supports_sdpa = True
874
+ _supports_cache_class = True
875
+
876
+ def _init_weights(self, module):
877
+ std = self.config.initializer_range
878
+ if isinstance(module, nn.Linear):
879
+ module.weight.data.normal_(mean=0.0, std=std)
880
+ if module.bias is not None:
881
+ module.bias.data.zero_()
882
+ elif isinstance(module, nn.Embedding):
883
+ module.weight.data.normal_(mean=0.0, std=std)
884
+ if module.padding_idx is not None:
885
+ module.weight.data[module.padding_idx].zero_()
886
+
887
+
888
+ class Embedding(nn.Module):
889
+ def __init__(self, config: PhiConfig):
890
+ super().__init__()
891
+ self.wte = nn.Embedding(
892
+ config.vocab_size, config.hidden_size, padding_idx=config.pad_token_id
893
+ )
894
+
895
+ def forward(self, input_ids: torch.LongTensor) -> torch.FloatTensor:
896
+ return self.wte(input_ids)
897
+
898
+ PHI_INPUTS_DOCSTRING = r"""
899
+ Args:
900
+ input_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`):
901
+ Indices of input sequence tokens in the vocabulary. Padding will be ignored by default should you provide
902
+ it.
903
+
904
+ Indices can be obtained using [`AutoTokenizer`]. See [`PreTrainedTokenizer.encode`] and
905
+ [`PreTrainedTokenizer.__call__`] for details.
906
+
907
+ [What are input IDs?](../glossary#input-ids)
908
+ attention_mask (`torch.Tensor` of shape `(batch_size, sequence_length)`, *optional*):
909
+ Mask to avoid performing attention on padding token indices. Mask values selected in `[0, 1]`:
910
+
911
+ - 1 for tokens that are **not masked**,
912
+ - 0 for tokens that are **masked**.
913
+
914
+ [What are attention masks?](../glossary#attention-mask)
915
+
916
+ Indices can be obtained using [`AutoTokenizer`]. See [`PreTrainedTokenizer.encode`] and
917
+ [`PreTrainedTokenizer.__call__`] for details.
918
+
919
+ If `past_key_values` is used, optionally only the last `input_ids` have to be input (see
920
+ `past_key_values`).
921
+
922
+ If you want to change padding behavior, you should read [`modeling_opt._prepare_decoder_attention_mask`]
923
+ and modify to your needs. See diagram 1 in [the paper](https://arxiv.org/abs/1910.13461) for more
924
+ information on the default strategy.
925
+
926
+ - 1 indicates the head is **not masked**,
927
+ - 0 indicates the head is **masked**.
928
+ position_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
929
+ Indices of positions of each input sequence tokens in the position embeddings. Selected in the range `[0,
930
+ config.n_positions - 1]`.
931
+
932
+ [What are position IDs?](../glossary#position-ids)
933
+ past_key_values (`Cache` or `tuple(tuple(torch.FloatTensor))`, *optional*):
934
+ Pre-computed hidden-states (key and values in the self-attention blocks and in the cross-attention
935
+ blocks) that can be used to speed up sequential decoding. This typically consists in the `past_key_values`
936
+ returned by the model at a previous stage of decoding, when `use_cache=True` or `config.use_cache=True`.
937
+
938
+ Two formats are allowed:
939
+ - a [`~cache_utils.Cache`] instance;
940
+ - Tuple of `tuple(torch.FloatTensor)` of length `config.n_layers`, with each tuple having 2 tensors of
941
+ shape `(batch_size, num_heads, sequence_length, embed_size_per_head)`). This is also known as the legacy
942
+ cache format.
943
+
944
+ The model will output the same cache format that is fed as input. If no `past_key_values` are passed, the
945
+ legacy cache format will be returned.
946
+
947
+ If `past_key_values` are used, the user can optionally input only the last `input_ids` (those that don't
948
+ have their past key value states given to this model) of shape `(batch_size, 1)` instead of all `input_ids`
949
+ of shape `(batch_size, sequence_length)`.
950
+ inputs_embeds (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`, *optional*):
951
+ Optionally, instead of passing `input_ids` you can choose to directly pass an embedded representation. This
952
+ is useful if you want more control over how to convert `input_ids` indices into associated vectors than the
953
+ model's internal embedding lookup matrix.
954
+ use_cache (`bool`, *optional*):
955
+ If set to `True`, `past_key_values` key value states are returned and can be used to speed up decoding (see
956
+ `past_key_values`).
957
+ output_attentions (`bool`, *optional*):
958
+ Whether or not to return the attentions tensors of all attention layers. See `attentions` under returned
959
+ tensors for more detail.
960
+ output_hidden_states (`bool`, *optional*):
961
+ Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors for
962
+ more detail.
963
+ return_dict (`bool`, *optional*):
964
+ Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple.
965
+ cache_position (`torch.LongTensor` of shape `(sequence_length)`, *optional*):
966
+ Indices depicting the position of the input sequence tokens in the sequence. Contrarily to `position_ids`,
967
+ this tensor is not affected by padding. It is used to update the cache in the correct position and to infer
968
+ the complete sequence length.
969
+ """
970
+
971
+
972
+ @add_start_docstrings(
973
+ "The bare Phi Model outputting raw hidden-states without any specific head on top.",
974
+ PHI_START_DOCSTRING,
975
+ )
976
+ class PhiModel(PhiPreTrainedModel):
977
+ """
978
+ Transformer decoder consisting of *config.num_hidden_layers* layers. Each layer is a [`PhiDecoderLayer`]
979
+
980
+ Args:
981
+ config: PhiConfig
982
+ """
983
+
984
+ def __init__(self, config: PhiConfig):
985
+ super().__init__(config)
986
+ self.padding_idx = config.pad_token_id
987
+ self.vocab_size = config.vocab_size
988
+
989
+ self.embd = Embedding(config)
990
+ self.embed_dropout = nn.Dropout(config.embd_pdrop)
991
+ self.h = nn.ModuleList(
992
+ [
993
+ PhiDecoderLayer(config, layer_idx)
994
+ for layer_idx in range(config.num_hidden_layers)
995
+ ]
996
+ )
997
+
998
+ self._use_flash_attention_2 = config._attn_implementation == "flash_attention_2"
999
+ self._use_sdpa = config._attn_implementation == "sdpa"
1000
+
1001
+ self.gradient_checkpointing = False
1002
+ # Initialize weights and apply final processing
1003
+ self.post_init()
1004
+
1005
+ def get_input_embeddings(self):
1006
+ return self.embd.wte
1007
+
1008
+ def set_input_embeddings(self, value):
1009
+ self.embd.wte = value
1010
+
1011
+ @add_start_docstrings_to_model_forward(PHI_INPUTS_DOCSTRING)
1012
+ def forward(
1013
+ self,
1014
+ input_ids: torch.LongTensor = None,
1015
+ attention_mask: Optional[torch.Tensor] = None,
1016
+ position_ids: Optional[torch.LongTensor] = None,
1017
+ past_key_values: Optional[List[torch.FloatTensor]] = None,
1018
+ inputs_embeds: Optional[torch.FloatTensor] = None,
1019
+ use_cache: Optional[bool] = None,
1020
+ output_attentions: Optional[bool] = None,
1021
+ output_hidden_states: Optional[bool] = None,
1022
+ return_dict: Optional[bool] = None,
1023
+ cache_position: Optional[torch.LongTensor] = None,
1024
+ ) -> Union[Tuple, BaseModelOutputWithPast]:
1025
+ output_attentions = (
1026
+ output_attentions
1027
+ if output_attentions is not None
1028
+ else self.config.output_attentions
1029
+ )
1030
+ output_hidden_states = (
1031
+ output_hidden_states
1032
+ if output_hidden_states is not None
1033
+ else self.config.output_hidden_states
1034
+ )
1035
+ use_cache = use_cache if use_cache is not None else self.config.use_cache
1036
+
1037
+ return_dict = (
1038
+ return_dict if return_dict is not None else self.config.use_return_dict
1039
+ )
1040
+
1041
+ if (input_ids is None) ^ (inputs_embeds is not None):
1042
+ raise ValueError(
1043
+ "You cannot specify both input_ids and inputs_embeds at the same time, and must specify either one"
1044
+ )
1045
+
1046
+ if self.gradient_checkpointing and self.training:
1047
+ if use_cache:
1048
+ logger.warning_once(
1049
+ "`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`..."
1050
+ )
1051
+ use_cache = False
1052
+
1053
+ use_legacy_cache = False
1054
+ if use_cache and not isinstance(past_key_values, Cache) and not self.training:
1055
+ use_legacy_cache = True
1056
+ past_key_values = DynamicCache.from_legacy_cache(past_key_values)
1057
+ logger.warning_once(
1058
+ "We detected that you are passing `past_key_values` as a tuple and this is deprecated and will be removed in v4.43. "
1059
+ "Please use an appropriate `Cache` class (https://huggingface.co/docs/transformers/internal/generation_utils#transformers.Cache)"
1060
+ )
1061
+
1062
+ if inputs_embeds is None:
1063
+ inputs_embeds = self.embd(input_ids)
1064
+
1065
+ if cache_position is None:
1066
+ past_seen_tokens = (
1067
+ past_key_values.get_seq_length() if past_key_values is not None else 0
1068
+ )
1069
+ cache_position = torch.arange(
1070
+ past_seen_tokens,
1071
+ past_seen_tokens + inputs_embeds.shape[1],
1072
+ device=inputs_embeds.device,
1073
+ )
1074
+ if position_ids is None:
1075
+ position_ids = cache_position.unsqueeze(0)
1076
+
1077
+ causal_mask = self._update_causal_mask(
1078
+ attention_mask,
1079
+ inputs_embeds,
1080
+ cache_position,
1081
+ past_key_values,
1082
+ output_attentions,
1083
+ )
1084
+
1085
+ hidden_states = inputs_embeds
1086
+
1087
+ # decoder layers
1088
+ all_hidden_states = () if output_hidden_states else None
1089
+ all_self_attns = () if output_attentions else None
1090
+ next_decoder_cache = None
1091
+
1092
+ for decoder_layer in self.h:
1093
+ if output_hidden_states:
1094
+ all_hidden_states += (hidden_states,)
1095
+
1096
+ if self.gradient_checkpointing and self.training:
1097
+ layer_outputs = self._gradient_checkpointing_func(
1098
+ decoder_layer.__call__,
1099
+ hidden_states,
1100
+ causal_mask,
1101
+ position_ids,
1102
+ output_attentions,
1103
+ use_cache,
1104
+ past_key_values,
1105
+ cache_position,
1106
+ )
1107
+ else:
1108
+ layer_outputs = decoder_layer(
1109
+ hidden_states,
1110
+ attention_mask=causal_mask,
1111
+ position_ids=position_ids,
1112
+ past_key_value=past_key_values,
1113
+ output_attentions=output_attentions,
1114
+ use_cache=use_cache,
1115
+ cache_position=cache_position,
1116
+ )
1117
+
1118
+ hidden_states = layer_outputs[0]
1119
+
1120
+ if use_cache:
1121
+ next_decoder_cache = layer_outputs[2 if output_attentions else 1]
1122
+
1123
+ if output_attentions:
1124
+ all_self_attns += (layer_outputs[1],)
1125
+
1126
+ # add hidden states from the last decoder layer
1127
+ if output_hidden_states:
1128
+ all_hidden_states += (hidden_states,)
1129
+
1130
+ next_cache = None
1131
+ if use_cache:
1132
+ next_cache = (
1133
+ next_decoder_cache.to_legacy_cache()
1134
+ if use_legacy_cache
1135
+ else next_decoder_cache
1136
+ )
1137
+ if not return_dict:
1138
+ return tuple(
1139
+ v
1140
+ for v in [hidden_states, next_cache, all_hidden_states, all_self_attns]
1141
+ if v is not None
1142
+ )
1143
+ return BaseModelOutputWithPast(
1144
+ last_hidden_state=hidden_states,
1145
+ past_key_values=next_cache,
1146
+ hidden_states=all_hidden_states,
1147
+ attentions=all_self_attns,
1148
+ )
1149
+
1150
+ # Copied from transformers.models.llama.modeling_llama.LlamaModel._update_causal_mask
1151
+ def _update_causal_mask(
1152
+ self,
1153
+ attention_mask: torch.Tensor,
1154
+ input_tensor: torch.Tensor,
1155
+ cache_position: torch.Tensor,
1156
+ past_key_values: Cache,
1157
+ output_attentions: bool,
1158
+ ):
1159
+ # TODO: As of torch==2.2.0, the `attention_mask` passed to the model in `generate` is 2D and of dynamic length even when the static
1160
+ # KV cache is used. This is an issue for torch.compile which then recaptures cudagraphs at each decode steps due to the dynamic shapes.
1161
+ # (`recording cudagraph tree for symint key 13`, etc.), which is VERY slow. A workaround is `@torch.compiler.disable`, but this prevents using
1162
+ # `fullgraph=True`. See more context in https://github.com/huggingface/transformers/pull/29114
1163
+
1164
+ if self.config._attn_implementation == "flash_attention_2":
1165
+ if attention_mask is not None and 0.0 in attention_mask:
1166
+ return attention_mask
1167
+ return None
1168
+
1169
+ # For SDPA, when possible, we will rely on its `is_causal` argument instead of its `attn_mask` argument, in
1170
+ # order to dispatch on Flash Attention 2. This feature is not compatible with static cache, as SDPA will fail
1171
+ # to infer the attention mask.
1172
+ past_seen_tokens = (
1173
+ past_key_values.get_seq_length() if past_key_values is not None else 0
1174
+ )
1175
+ using_static_cache = isinstance(past_key_values, StaticCache)
1176
+
1177
+ # When output attentions is True, sdpa implementation's forward method calls the eager implementation's forward
1178
+ if (
1179
+ self.config._attn_implementation == "sdpa"
1180
+ and not using_static_cache
1181
+ and not output_attentions
1182
+ ):
1183
+ if AttentionMaskConverter._ignore_causal_mask_sdpa(
1184
+ attention_mask,
1185
+ inputs_embeds=input_tensor,
1186
+ past_key_values_length=past_seen_tokens,
1187
+ is_training=self.training,
1188
+ ):
1189
+ return None
1190
+
1191
+ dtype, device = input_tensor.dtype, input_tensor.device
1192
+ min_dtype = torch.finfo(dtype).min
1193
+ sequence_length = input_tensor.shape[1]
1194
+ if using_static_cache:
1195
+ target_length = past_key_values.get_max_length()
1196
+ else:
1197
+ target_length = (
1198
+ attention_mask.shape[-1]
1199
+ if isinstance(attention_mask, torch.Tensor)
1200
+ else past_seen_tokens + sequence_length + 1
1201
+ )
1202
+
1203
+ # In case the provided `attention` mask is 2D, we generate a causal mask here (4D).
1204
+ causal_mask = _prepare_4d_causal_attention_mask_with_cache_position(
1205
+ attention_mask,
1206
+ sequence_length=sequence_length,
1207
+ target_length=target_length,
1208
+ dtype=dtype,
1209
+ device=device,
1210
+ min_dtype=min_dtype,
1211
+ cache_position=cache_position,
1212
+ batch_size=input_tensor.shape[0],
1213
+ )
1214
+
1215
+ if (
1216
+ self.config._attn_implementation == "sdpa"
1217
+ and attention_mask is not None
1218
+ and attention_mask.device.type == "cuda"
1219
+ and not output_attentions
1220
+ ):
1221
+ # Attend to all tokens in fully masked rows in the causal_mask, for example the relevant first rows when
1222
+ # using left padding. This is required by F.scaled_dot_product_attention memory-efficient attention path.
1223
+ # Details: https://github.com/pytorch/pytorch/issues/110213
1224
+ causal_mask = AttentionMaskConverter._unmask_unattended(
1225
+ causal_mask, min_dtype
1226
+ )
1227
+
1228
+ return causal_mask
1229
+
1230
+
1231
+ class CausalLMHead(nn.Module):
1232
+ """Causal Language Modeling head. Simplified version."""
1233
+
1234
+ def __init__(self, config):
1235
+ super().__init__()
1236
+ self.ln = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
1237
+ self.linear = nn.Linear(config.hidden_size, config.vocab_size)
1238
+
1239
+ def forward(self, hidden_states):
1240
+ return self.linear(self.ln(hidden_states))
1241
+
1242
+
1243
+ class PhiForCausalLM(PhiPreTrainedModel):
1244
+
1245
+ # Copied from transformers.models.llama.modeling_llama.LlamaForCausalLM.__init__ with Llama->Phi,bias=False->bias=True
1246
+ def __init__(self, config):
1247
+ super().__init__(config)
1248
+ self.transformer = PhiModel(config)
1249
+ self.vocab_size = config.vocab_size
1250
+ self.lm_head = CausalLMHead(config)
1251
+
1252
+ # Initialize weights and apply final processing
1253
+ self.post_init()
1254
+
1255
+ # Copied from transformers.models.llama.modeling_llama.LlamaForCausalLM.get_input_embeddings
1256
+ def get_input_embeddings(self):
1257
+ return self.transformer.embd.wte
1258
+
1259
+ # Copied from transformers.models.llama.modeling_llama.LlamaForCausalLM.set_input_embeddings
1260
+ def set_input_embeddings(self, value):
1261
+ self.transformer.embd.wte = value
1262
+
1263
+ # Copied from transformers.models.llama.modeling_llama.LlamaForCausalLM.get_output_embeddings
1264
+ def get_output_embeddings(self):
1265
+ return self.lm_head.linear
1266
+
1267
+ # Copied from transformers.models.llama.modeling_llama.LlamaForCausalLM.set_output_embeddings
1268
+ def set_output_embeddings(self, new_embeddings):
1269
+ self.lm_head.linear = new_embeddings
1270
+
1271
+ # Copied from transformers.models.llama.modeling_llama.LlamaForCausalLM.set_decoder
1272
+ def set_decoder(self, decoder):
1273
+ self.model = decoder
1274
+
1275
+ # Copied from transformers.models.llama.modeling_llama.LlamaForCausalLM.get_decoder
1276
+ def get_decoder(self):
1277
+ return self.model
1278
+
1279
+ @add_start_docstrings_to_model_forward(PHI_INPUTS_DOCSTRING)
1280
+ @replace_return_docstrings(
1281
+ output_type=CausalLMOutputWithPast, config_class=_CONFIG_FOR_DOC
1282
+ )
1283
+ def forward(
1284
+ self,
1285
+ input_ids: torch.LongTensor = None,
1286
+ attention_mask: Optional[torch.Tensor] = None,
1287
+ position_ids: Optional[torch.LongTensor] = None,
1288
+ past_key_values: Optional[List[torch.FloatTensor]] = None,
1289
+ inputs_embeds: Optional[torch.FloatTensor] = None,
1290
+ labels: Optional[torch.LongTensor] = None,
1291
+ use_cache: Optional[bool] = None,
1292
+ output_attentions: Optional[bool] = None,
1293
+ output_hidden_states: Optional[bool] = None,
1294
+ return_dict: Optional[bool] = None,
1295
+ cache_position: Optional[torch.LongTensor] = None,
1296
+ num_logits_to_keep: int = 0,
1297
+ ) -> Union[Tuple, CausalLMOutputWithPast]:
1298
+ r"""
1299
+ Args:
1300
+ labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
1301
+ Labels for computing the masked language modeling loss. Indices should either be in `[0, ...,
1302
+ config.vocab_size]` or -100 (see `input_ids` docstring). Tokens with indices set to `-100` are ignored
1303
+ (masked), the loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]`.
1304
+
1305
+ num_logits_to_keep (`int`, *optional*):
1306
+ Calculate logits for the last `num_logits_to_keep` tokens. If `0`, calculate logits for all
1307
+ `input_ids` (special case). Only last token logits are needed for generation, and calculating them only for that
1308
+ token can save memory, which becomes pretty significant for long sequences or large vocabulary size.
1309
+
1310
+ Returns:
1311
+
1312
+ Example:
1313
+
1314
+ ```python
1315
+ >>> from transformers import AutoTokenizer, PhiForCausalLM
1316
+
1317
+ >>> model = PhiForCausalLM.from_pretrained("microsoft/phi-1")
1318
+ >>> tokenizer = AutoTokenizer.from_pretrained("microsoft/phi-1")
1319
+
1320
+ >>> prompt = "This is an example script ."
1321
+ >>> inputs = tokenizer(prompt, return_tensors="pt")
1322
+
1323
+ >>> # Generate
1324
+ >>> generate_ids = model.generate(inputs.input_ids, max_length=30)
1325
+ >>> tokenizer.batch_decode(generate_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False)[0]
1326
+ 'This is an example script .\n\n\n\nfrom typing import List\n\ndef find_most_common_letter(words: List[str'
1327
+ ```"""
1328
+
1329
+ output_attentions = (
1330
+ output_attentions
1331
+ if output_attentions is not None
1332
+ else self.config.output_attentions
1333
+ )
1334
+ output_hidden_states = (
1335
+ output_hidden_states
1336
+ if output_hidden_states is not None
1337
+ else self.config.output_hidden_states
1338
+ )
1339
+ return_dict = (
1340
+ return_dict if return_dict is not None else self.config.use_return_dict
1341
+ )
1342
+
1343
+ # decoder outputs consists of (dec_features, layer_state, dec_hidden, dec_attn)
1344
+ outputs = self.transformer(
1345
+ input_ids=input_ids,
1346
+ attention_mask=attention_mask,
1347
+ position_ids=position_ids,
1348
+ past_key_values=past_key_values,
1349
+ inputs_embeds=inputs_embeds,
1350
+ use_cache=use_cache,
1351
+ output_attentions=output_attentions,
1352
+ output_hidden_states=output_hidden_states,
1353
+ return_dict=return_dict,
1354
+ cache_position=cache_position,
1355
+ )
1356
+
1357
+ hidden_states = outputs[0]
1358
+ logits = self.lm_head(hidden_states[:, -num_logits_to_keep:, :]).float()
1359
+
1360
+ loss = None
1361
+ if labels is not None:
1362
+ # Upcast to float if we need to compute the loss to avoid potential precision issues
1363
+ logits = logits.float()
1364
+ # Shift so that tokens < n predict n
1365
+ shift_logits = logits[..., :-1, :].contiguous()
1366
+ shift_labels = labels[..., 1:].contiguous()
1367
+ # Flatten the tokens
1368
+ loss_fct = CrossEntropyLoss()
1369
+ shift_logits = shift_logits.view(-1, self.config.vocab_size)
1370
+ shift_labels = shift_labels.view(-1)
1371
+ # Enable model parallelism
1372
+ shift_labels = shift_labels.to(shift_logits.device)
1373
+ loss = loss_fct(shift_logits, shift_labels)
1374
+
1375
+ if not return_dict:
1376
+ output = (logits,) + outputs[1:]
1377
+ return (loss,) + output if loss is not None else output
1378
+
1379
+ return CausalLMOutputWithPast(
1380
+ loss=loss,
1381
+ logits=logits,
1382
+ past_key_values=outputs.past_key_values,
1383
+ hidden_states=outputs.hidden_states,
1384
+ attentions=outputs.attentions,
1385
+ )
1386
+
1387
+ # Copied from transformers.models.llama.modeling_llama.LlamaForCausalLM.prepare_inputs_for_generation
1388
+ def prepare_inputs_for_generation(
1389
+ self,
1390
+ input_ids,
1391
+ past_key_values=None,
1392
+ attention_mask=None,
1393
+ inputs_embeds=None,
1394
+ cache_position=None,
1395
+ position_ids=None,
1396
+ use_cache=True,
1397
+ num_logits_to_keep=0,
1398
+ **kwargs,
1399
+ ):
1400
+ # If we have cache: let's slice `input_ids` through `cache_position`, to keep only the unprocessed tokens
1401
+ # Exception 1: when passing input_embeds, input_ids may be missing entries
1402
+ # Exception 2: some generation methods do special slicing of input_ids, so we don't need to do it here
1403
+ if past_key_values is not None:
1404
+ if inputs_embeds is not None: # Exception 1
1405
+ input_ids = input_ids[:, -cache_position.shape[0] :]
1406
+ elif (
1407
+ input_ids.shape[1] != cache_position.shape[0]
1408
+ ): # Default case (the "else", a no op, is Exception 2)
1409
+ input_ids = input_ids[:, cache_position]
1410
+
1411
+ if attention_mask is not None and position_ids is None:
1412
+ # create position_ids on the fly for batch generation
1413
+ position_ids = attention_mask.long().cumsum(-1) - 1
1414
+ position_ids.masked_fill_(attention_mask == 0, 1)
1415
+ if past_key_values:
1416
+ position_ids = position_ids[:, -input_ids.shape[1] :]
1417
+
1418
+ # This `clone` call is needed to avoid recapturing cuda graphs with `torch.compile`'s `mode="reduce-overhead`, as otherwise the input `position_ids` would have various stride during the decoding. Here, simply using `.contiguous()` is not sufficient as in the batch size = 1 case, `position_ids` is already contiguous but with varying stride which retriggers a capture.
1419
+ position_ids = position_ids.clone(memory_format=torch.contiguous_format)
1420
+
1421
+ # if `inputs_embeds` are passed, we only want to use them in the 1st generation step
1422
+ if inputs_embeds is not None and cache_position[0] == 0:
1423
+ model_inputs = {"inputs_embeds": inputs_embeds, "input_ids": None}
1424
+ else:
1425
+ # The clone here is for the same reason as for `position_ids`.
1426
+ model_inputs = {
1427
+ "input_ids": input_ids.clone(memory_format=torch.contiguous_format),
1428
+ "inputs_embeds": None,
1429
+ }
1430
+
1431
+ if isinstance(past_key_values, StaticCache) and attention_mask.ndim == 2:
1432
+ if model_inputs["inputs_embeds"] is not None:
1433
+ batch_size, sequence_length, _ = model_inputs["inputs_embeds"].shape
1434
+ device = model_inputs["inputs_embeds"].device
1435
+ else:
1436
+ batch_size, sequence_length = model_inputs["input_ids"].shape
1437
+ device = model_inputs["input_ids"].device
1438
+
1439
+ dtype = self.lm_head.weight.dtype
1440
+ min_dtype = torch.finfo(dtype).min
1441
+
1442
+ attention_mask = _prepare_4d_causal_attention_mask_with_cache_position(
1443
+ attention_mask,
1444
+ sequence_length=sequence_length,
1445
+ target_length=past_key_values.get_max_length(),
1446
+ dtype=dtype,
1447
+ device=device,
1448
+ min_dtype=min_dtype,
1449
+ cache_position=cache_position,
1450
+ batch_size=batch_size,
1451
+ )
1452
+
1453
+ model_inputs.update(
1454
+ {
1455
+ "position_ids": position_ids,
1456
+ "cache_position": cache_position,
1457
+ "past_key_values": past_key_values,
1458
+ "use_cache": use_cache,
1459
+ "attention_mask": attention_mask,
1460
+ "num_logits_to_keep": num_logits_to_keep,
1461
+ }
1462
+ )
1463
+ return model_inputs