guanwenyu1995 commited on
Commit
f9007d7
·
verified ·
1 Parent(s): 8b7525a

Upload modeling_llama.py with huggingface_hub

Browse files
Files changed (1) hide show
  1. modeling_llama.py +883 -211
modeling_llama.py CHANGED
@@ -17,19 +17,19 @@
17
  # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
18
  # See the License for the specific language governing permissions and
19
  # limitations under the License.
20
- from typing import Callable, Optional, Union
 
21
 
22
  import torch
 
23
  import torch.utils.checkpoint
24
  from torch import nn
25
 
26
  from transformers.activations import ACT2FN
27
- from transformers.cache_utils import Cache, DynamicCache
28
  from transformers.generation import GenerationMixin
29
- from transformers.integrations import use_kernel_forward_from_hub
30
- from transformers.masking_utils import create_causal_mask
31
- from transformers.modeling_flash_attention_utils import FlashAttentionKwargs
32
- from transformers.modeling_layers import GradientCheckpointingLayer
33
  from transformers.modeling_outputs import (
34
  BaseModelOutputWithPast,
35
  CausalLMOutputWithPast,
@@ -37,15 +37,25 @@ from transformers.modeling_outputs import (
37
  SequenceClassifierOutputWithPast,
38
  TokenClassifierOutput,
39
  )
40
- from transformers.modeling_rope_utils import ROPE_INIT_FUNCTIONS, dynamic_rope_update
41
- from transformers.modeling_utils import ALL_ATTENTION_FUNCTIONS, PreTrainedModel
42
- from transformers.processing_utils import Unpack
43
- from transformers.utils import LossKwargs, auto_docstring, can_return_tuple, logging
 
 
 
 
 
 
 
44
  from .configuration_llama import LlamaConfig
45
 
46
 
47
  logger = logging.get_logger(__name__)
48
 
 
 
 
49
 
50
  def get_quantizer(quant_type="none", bit=4, group_size=128):
51
  if quant_type == "intsym":
@@ -140,7 +150,6 @@ class LinearQuantizer(nn.Linear):
140
  x = x + self.bias
141
  return x
142
 
143
- @use_kernel_forward_from_hub("RMSNorm")
144
  class LlamaRMSNorm(nn.Module):
145
  def __init__(self, hidden_size, eps=1e-6):
146
  """
@@ -161,40 +170,121 @@ class LlamaRMSNorm(nn.Module):
161
  return f"{tuple(self.weight.shape)}, eps={self.variance_epsilon}"
162
 
163
 
 
 
 
164
  class LlamaRotaryEmbedding(nn.Module):
165
- def __init__(self, config: LlamaConfig, device=None):
 
 
 
 
 
 
 
 
 
166
  super().__init__()
167
- # BC: "rope_type" was originally "type"
168
- if hasattr(config, "rope_scaling") and config.rope_scaling is not None:
169
- self.rope_type = config.rope_scaling.get("rope_type", config.rope_scaling.get("type"))
 
 
 
 
 
 
 
 
 
 
 
 
 
 
170
  else:
171
- self.rope_type = "default"
172
- self.max_seq_len_cached = config.max_position_embeddings
173
- self.original_max_seq_len = config.max_position_embeddings
 
 
 
 
174
 
175
  self.config = config
176
  self.rope_init_fn = ROPE_INIT_FUNCTIONS[self.rope_type]
177
 
178
- inv_freq, self.attention_scaling = self.rope_init_fn(self.config, device)
179
  self.register_buffer("inv_freq", inv_freq, persistent=False)
180
  self.original_inv_freq = self.inv_freq
181
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
182
  @torch.no_grad()
183
- @dynamic_rope_update # power user: used with advanced RoPE types (e.g. dynamic rope)
184
  def forward(self, x, position_ids):
185
- inv_freq_expanded = self.inv_freq[None, :, None].float().expand(position_ids.shape[0], -1, 1).to(x.device)
186
- position_ids_expanded = position_ids[:, None, :].float()
187
 
188
- device_type = x.device.type if isinstance(x.device.type, str) and x.device.type != "mps" else "cpu"
189
- with torch.autocast(device_type=device_type, enabled=False): # Force float32
 
 
 
 
 
190
  freqs = (inv_freq_expanded.float() @ position_ids_expanded.float()).transpose(1, 2)
191
  emb = torch.cat((freqs, freqs), dim=-1)
192
- cos = emb.cos() * self.attention_scaling
193
- sin = emb.sin() * self.attention_scaling
 
 
 
 
194
 
195
  return cos.to(dtype=x.dtype), sin.to(dtype=x.dtype)
196
 
197
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
198
  def rotate_half(x):
199
  """Rotates half the hidden dims of the input."""
200
  x1 = x[..., : x.shape[-1] // 2]
@@ -238,13 +328,28 @@ class LlamaMLP(nn.Module):
238
  self.gate_proj = LinearQuantizer(self.hidden_size, self.intermediate_size, bias=config.mlp_bias, quant_type="ternary", bit=4, group_size=-1)
239
  self.up_proj = LinearQuantizer(self.hidden_size, self.intermediate_size, bias=config.mlp_bias, quant_type="ternary", bit=4, group_size=-1)
240
  self.down_proj = LinearQuantizer(self.intermediate_size, self.hidden_size, bias=config.mlp_bias, quant_type="ternary", bit=4, group_size=-1)
241
- # self.gate_proj = nn.Linear(self.hidden_size, self.intermediate_size, bias=config.mlp_bias)
242
- # self.up_proj = nn.Linear(self.hidden_size, self.intermediate_size, bias=config.mlp_bias)
243
- # self.down_proj = nn.Linear(self.intermediate_size, self.hidden_size, bias=config.mlp_bias)
244
  self.act_fn = ACT2FN[config.hidden_act]
245
 
246
  def forward(self, x):
247
- down_proj = self.down_proj(self.act_fn(self.gate_proj(x)) * self.up_proj(x))
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
248
  return down_proj
249
 
250
 
@@ -260,79 +365,88 @@ def repeat_kv(hidden_states: torch.Tensor, n_rep: int) -> torch.Tensor:
260
  return hidden_states.reshape(batch, num_key_value_heads * n_rep, slen, head_dim)
261
 
262
 
263
- def eager_attention_forward(
264
- module: nn.Module,
265
- query: torch.Tensor,
266
- key: torch.Tensor,
267
- value: torch.Tensor,
268
- attention_mask: Optional[torch.Tensor],
269
- scaling: float,
270
- dropout: float = 0.0,
271
- **kwargs,
272
- ):
273
- key_states = repeat_kv(key, module.num_key_value_groups)
274
- value_states = repeat_kv(value, module.num_key_value_groups)
275
-
276
- attn_weights = torch.matmul(query, key_states.transpose(2, 3)) * scaling
277
- if attention_mask is not None:
278
- causal_mask = attention_mask[:, :, :, : key_states.shape[-2]]
279
- attn_weights = attn_weights + causal_mask
280
-
281
- attn_weights = nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to(query.dtype)
282
- attn_weights = nn.functional.dropout(attn_weights, p=dropout, training=module.training)
283
- attn_output = torch.matmul(attn_weights, value_states)
284
- attn_output = attn_output.transpose(1, 2).contiguous()
285
-
286
- return attn_output, attn_weights
287
-
288
-
289
  class LlamaAttention(nn.Module):
290
  """Multi-headed attention from 'Attention Is All You Need' paper"""
291
 
292
- def __init__(self, config: LlamaConfig, layer_idx: int):
293
  super().__init__()
294
  self.config = config
295
  self.layer_idx = layer_idx
296
- self.head_dim = getattr(config, "head_dim", config.hidden_size // config.num_attention_heads)
297
- self.num_key_value_groups = config.num_attention_heads // config.num_key_value_heads
298
- self.scaling = self.head_dim**-0.5
 
 
 
 
299
  self.attention_dropout = config.attention_dropout
 
 
 
 
 
 
 
300
  self.is_causal = True
301
 
302
- # self.q_proj = nn.Linear(
303
- # config.hidden_size, config.num_attention_heads * self.head_dim, bias=config.attention_bias
304
- # )
305
- # self.k_proj = nn.Linear(
306
- # config.hidden_size, config.num_key_value_heads * self.head_dim, bias=config.attention_bias
307
- # )
308
- # self.v_proj = nn.Linear(
309
- # config.hidden_size, config.num_key_value_heads * self.head_dim, bias=config.attention_bias
310
- # )
311
- # self.o_proj = nn.Linear(
312
- # config.num_attention_heads * self.head_dim, config.hidden_size, bias=config.attention_bias
313
- # )
314
- self.q_proj = LinearQuantizer(config.hidden_size, config.num_attention_heads * self.head_dim, bias=config.attention_bias, quant_type="ternary", bit=4, group_size=-1)
315
- self.k_proj = LinearQuantizer(config.hidden_size, config.num_key_value_heads * self.head_dim, bias=config.attention_bias, quant_type="ternary", bit=4, group_size=-1)
316
- self.v_proj = LinearQuantizer(config.hidden_size, config.num_key_value_heads * self.head_dim, bias=config.attention_bias, quant_type="ternary", bit=4, group_size=-1)
317
- self.o_proj = LinearQuantizer(config.num_attention_heads * self.head_dim, config.hidden_size, bias=config.attention_bias, quant_type="ternary", bit=4, group_size=-1)
318
 
319
  def forward(
320
  self,
321
  hidden_states: torch.Tensor,
322
- position_embeddings: tuple[torch.Tensor, torch.Tensor],
323
- attention_mask: Optional[torch.Tensor],
324
  past_key_value: Optional[Cache] = None,
 
 
325
  cache_position: Optional[torch.LongTensor] = None,
326
- **kwargs: Unpack[FlashAttentionKwargs],
327
- ) -> tuple[torch.Tensor, Optional[torch.Tensor], Optional[tuple[torch.Tensor]]]:
328
- input_shape = hidden_states.shape[:-1]
329
- hidden_shape = (*input_shape, -1, self.head_dim)
330
 
331
- query_states = self.q_proj(hidden_states).view(hidden_shape).transpose(1, 2)
332
- key_states = self.k_proj(hidden_states).view(hidden_shape).transpose(1, 2)
333
- value_states = self.v_proj(hidden_states).view(hidden_shape).transpose(1, 2)
 
 
 
 
334
 
335
- cos, sin = position_embeddings
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
336
  query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin)
337
 
338
  if past_key_value is not None:
@@ -340,32 +454,274 @@ class LlamaAttention(nn.Module):
340
  cache_kwargs = {"sin": sin, "cos": cos, "cache_position": cache_position}
341
  key_states, value_states = past_key_value.update(key_states, value_states, self.layer_idx, cache_kwargs)
342
 
343
- attention_interface: Callable = eager_attention_forward
344
- if self.config._attn_implementation != "eager":
345
- attention_interface = ALL_ATTENTION_FUNCTIONS[self.config._attn_implementation]
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
346
 
347
- attn_output, attn_weights = attention_interface(
348
- self,
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
349
  query_states,
350
  key_states,
351
  value_states,
352
  attention_mask,
353
- dropout=0.0 if not self.training else self.attention_dropout,
354
- scaling=self.scaling,
355
- **kwargs,
 
 
 
356
  )
357
 
358
- attn_output = attn_output.reshape(*input_shape, -1).contiguous()
359
  attn_output = self.o_proj(attn_output)
360
- return attn_output, attn_weights
361
 
 
 
362
 
363
- class LlamaDecoderLayer(GradientCheckpointingLayer):
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
364
  def __init__(self, config: LlamaConfig, layer_idx: int):
365
  super().__init__()
366
  self.hidden_size = config.hidden_size
367
 
368
- self.self_attn = LlamaAttention(config=config, layer_idx=layer_idx)
369
 
370
  self.mlp = LlamaMLP(config)
371
  self.input_layernorm = LlamaRMSNorm(config.hidden_size, eps=config.rms_norm_eps)
@@ -380,14 +736,37 @@ class LlamaDecoderLayer(GradientCheckpointingLayer):
380
  output_attentions: Optional[bool] = False,
381
  use_cache: Optional[bool] = False,
382
  cache_position: Optional[torch.LongTensor] = None,
383
- position_embeddings: Optional[tuple[torch.Tensor, torch.Tensor]] = None, # necessary, but kept here for BC
384
- **kwargs: Unpack[FlashAttentionKwargs],
385
- ) -> tuple[torch.FloatTensor, Optional[tuple[torch.FloatTensor, torch.FloatTensor]]]:
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
386
  residual = hidden_states
 
387
  hidden_states = self.input_layernorm(hidden_states)
388
 
389
  # Self Attention
390
- hidden_states, self_attn_weights = self.self_attn(
391
  hidden_states=hidden_states,
392
  attention_mask=attention_mask,
393
  position_ids=position_ids,
@@ -407,27 +786,48 @@ class LlamaDecoderLayer(GradientCheckpointingLayer):
407
  hidden_states = residual + hidden_states
408
 
409
  outputs = (hidden_states,)
 
410
  if output_attentions:
411
  outputs += (self_attn_weights,)
412
 
 
 
 
413
  return outputs
414
 
415
 
416
- @auto_docstring
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
417
  class LlamaPreTrainedModel(PreTrainedModel):
418
  config_class = LlamaConfig
419
  base_model_prefix = "model"
420
  supports_gradient_checkpointing = True
421
  _no_split_modules = ["LlamaDecoderLayer"]
422
  _skip_keys_device_placement = ["past_key_values"]
423
- _supports_flash_attn_3 = True
424
  _supports_flash_attn_2 = True
425
  _supports_sdpa = True
426
- _supports_flex_attn = True
427
  _supports_cache_class = True
428
  _supports_quantized_cache = True
429
  _supports_static_cache = True
430
- _supports_attention_backend = True
431
 
432
  def _init_weights(self, module):
433
  std = self.config.initializer_range
@@ -439,12 +839,95 @@ class LlamaPreTrainedModel(PreTrainedModel):
439
  module.weight.data.normal_(mean=0.0, std=std)
440
  if module.padding_idx is not None:
441
  module.weight.data[module.padding_idx].zero_()
442
- elif isinstance(module, LlamaRMSNorm):
443
- module.weight.data.fill_(1.0)
444
 
445
 
446
- @auto_docstring
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
447
  class LlamaModel(LlamaPreTrainedModel):
 
 
 
 
 
 
 
448
  def __init__(self, config: LlamaConfig):
449
  super().__init__(config)
450
  self.padding_idx = config.pad_token_id
@@ -467,26 +950,26 @@ class LlamaModel(LlamaPreTrainedModel):
467
  def set_input_embeddings(self, value):
468
  self.embed_tokens = value
469
 
470
- @can_return_tuple
471
- @auto_docstring
472
  def forward(
473
  self,
474
- input_ids: Optional[torch.LongTensor] = None,
475
  attention_mask: Optional[torch.Tensor] = None,
476
  position_ids: Optional[torch.LongTensor] = None,
477
- past_key_values: Optional[Cache] = None,
478
  inputs_embeds: Optional[torch.FloatTensor] = None,
479
  use_cache: Optional[bool] = None,
480
  output_attentions: Optional[bool] = None,
481
  output_hidden_states: Optional[bool] = None,
 
482
  cache_position: Optional[torch.LongTensor] = None,
483
- **flash_attn_kwargs: Unpack[FlashAttentionKwargs],
484
- ) -> BaseModelOutputWithPast:
485
  output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
486
  output_hidden_states = (
487
  output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
488
  )
489
  use_cache = use_cache if use_cache is not None else self.config.use_cache
 
490
 
491
  if (input_ids is None) ^ (inputs_embeds is not None):
492
  raise ValueError("You must specify exactly one of input_ids or inputs_embeds")
@@ -497,34 +980,34 @@ class LlamaModel(LlamaPreTrainedModel):
497
  )
498
  use_cache = False
499
 
500
- # TODO (joao): remove this exception in v4.56 -- it exists for users that try to pass a legacy cache
501
- if not isinstance(past_key_values, (type(None), Cache)):
502
- raise ValueError("The `past_key_values` should be either a `Cache` object or `None`.")
503
-
504
  if inputs_embeds is None:
505
  inputs_embeds = self.embed_tokens(input_ids)
506
 
507
- if use_cache and past_key_values is None:
508
- past_key_values = DynamicCache()
 
 
 
 
 
 
 
 
 
 
 
509
 
510
  if cache_position is None:
511
  past_seen_tokens = past_key_values.get_seq_length() if past_key_values is not None else 0
512
  cache_position = torch.arange(
513
  past_seen_tokens, past_seen_tokens + inputs_embeds.shape[1], device=inputs_embeds.device
514
  )
515
-
516
  if position_ids is None:
517
  position_ids = cache_position.unsqueeze(0)
518
 
519
- causal_mask = create_causal_mask(
520
- config=self.config,
521
- input_embeds=inputs_embeds,
522
- attention_mask=attention_mask,
523
- cache_position=cache_position,
524
- past_key_values=past_key_values,
525
- position_ids=position_ids,
526
  )
527
-
528
  hidden_states = inputs_embeds
529
 
530
  # create position embeddings to be shared across the decoder layers
@@ -533,25 +1016,41 @@ class LlamaModel(LlamaPreTrainedModel):
533
  # decoder layers
534
  all_hidden_states = () if output_hidden_states else None
535
  all_self_attns = () if output_attentions else None
 
536
 
537
- for decoder_layer in self.layers[: self.config.num_hidden_layers]:
538
  if output_hidden_states:
539
  all_hidden_states += (hidden_states,)
540
 
541
- layer_outputs = decoder_layer(
542
- hidden_states,
543
- attention_mask=causal_mask,
544
- position_ids=position_ids,
545
- past_key_value=past_key_values,
546
- output_attentions=output_attentions,
547
- use_cache=use_cache,
548
- cache_position=cache_position,
549
- position_embeddings=position_embeddings,
550
- **flash_attn_kwargs,
551
- )
 
 
 
 
 
 
 
 
 
 
 
 
552
 
553
  hidden_states = layer_outputs[0]
554
 
 
 
 
555
  if output_attentions:
556
  all_self_attns += (layer_outputs[1],)
557
 
@@ -561,22 +1060,143 @@ class LlamaModel(LlamaPreTrainedModel):
561
  if output_hidden_states:
562
  all_hidden_states += (hidden_states,)
563
 
 
 
 
 
 
 
564
  return BaseModelOutputWithPast(
565
  last_hidden_state=hidden_states,
566
- past_key_values=past_key_values if use_cache else None,
567
  hidden_states=all_hidden_states,
568
  attentions=all_self_attns,
569
  )
570
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
571
 
572
- class KwargsForCausalLM(FlashAttentionKwargs, LossKwargs): ...
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
573
 
574
 
575
- @auto_docstring
576
  class LlamaForCausalLM(LlamaPreTrainedModel, GenerationMixin):
577
  _tied_weights_keys = ["lm_head.weight"]
578
- _tp_plan = {"lm_head": "colwise_rep"}
579
- _pp_plan = {"lm_head": (["hidden_states"], ["logits"])}
580
 
581
  def __init__(self, config):
582
  super().__init__(config)
@@ -605,28 +1225,37 @@ class LlamaForCausalLM(LlamaPreTrainedModel, GenerationMixin):
605
  def get_decoder(self):
606
  return self.model
607
 
608
- @can_return_tuple
609
- @auto_docstring
610
  def forward(
611
  self,
612
- input_ids: Optional[torch.LongTensor] = None,
613
  attention_mask: Optional[torch.Tensor] = None,
614
  position_ids: Optional[torch.LongTensor] = None,
615
- past_key_values: Optional[Cache] = None,
616
  inputs_embeds: Optional[torch.FloatTensor] = None,
617
  labels: Optional[torch.LongTensor] = None,
618
  use_cache: Optional[bool] = None,
619
  output_attentions: Optional[bool] = None,
620
  output_hidden_states: Optional[bool] = None,
 
621
  cache_position: Optional[torch.LongTensor] = None,
622
- logits_to_keep: Union[int, torch.Tensor] = 0,
623
- **kwargs: Unpack[KwargsForCausalLM],
624
- ) -> CausalLMOutputWithPast:
625
  r"""
626
- labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
627
- Labels for computing the masked language modeling loss. Indices should either be in `[0, ...,
628
- config.vocab_size]` or -100 (see `input_ids` docstring). Tokens with indices set to `-100` are ignored
629
- (masked), the loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]`.
 
 
 
 
 
 
 
 
630
 
631
  Example:
632
 
@@ -648,9 +1277,10 @@ class LlamaForCausalLM(LlamaPreTrainedModel, GenerationMixin):
648
  output_hidden_states = (
649
  output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
650
  )
 
651
 
652
  # decoder outputs consists of (dec_features, layer_state, dec_hidden, dec_attn)
653
- outputs: BaseModelOutputWithPast = self.model(
654
  input_ids=input_ids,
655
  attention_mask=attention_mask,
656
  position_ids=position_ids,
@@ -659,18 +1289,26 @@ class LlamaForCausalLM(LlamaPreTrainedModel, GenerationMixin):
659
  use_cache=use_cache,
660
  output_attentions=output_attentions,
661
  output_hidden_states=output_hidden_states,
 
662
  cache_position=cache_position,
663
- **kwargs,
664
  )
665
 
666
- hidden_states = outputs.last_hidden_state
667
- # Only compute necessary logits, and do not upcast them to float if we are not computing the loss
668
- slice_indices = slice(-logits_to_keep, None) if isinstance(logits_to_keep, int) else logits_to_keep
669
- logits = self.lm_head(hidden_states[:, slice_indices, :])
 
 
 
 
670
 
671
  loss = None
672
  if labels is not None:
673
- loss = self.loss_function(logits=logits, labels=labels, vocab_size=self.config.vocab_size, **kwargs)
 
 
 
 
674
 
675
  return CausalLMOutputWithPast(
676
  loss=loss,
@@ -681,8 +1319,8 @@ class LlamaForCausalLM(LlamaPreTrainedModel, GenerationMixin):
681
  )
682
 
683
 
684
- @auto_docstring(
685
- custom_intro="""
686
  The LLaMa Model transformer with a sequence classification head on top (linear layer).
687
 
688
  [`LlamaForSequenceClassification`] uses the last token in order to do the classification, as other causal models
@@ -693,7 +1331,8 @@ class LlamaForCausalLM(LlamaPreTrainedModel, GenerationMixin):
693
  no `pad_token_id` is defined, it simply takes the last value in each row of the batch. Since it cannot guess the
694
  padding tokens when `inputs_embeds` are passed instead of `input_ids`, it does the same (take the last value in
695
  each row of the batch).
696
- """
 
697
  )
698
  class LlamaForSequenceClassification(LlamaPreTrainedModel):
699
  def __init__(self, config):
@@ -711,28 +1350,29 @@ class LlamaForSequenceClassification(LlamaPreTrainedModel):
711
  def set_input_embeddings(self, value):
712
  self.model.embed_tokens = value
713
 
714
- @can_return_tuple
715
- @auto_docstring
716
  def forward(
717
  self,
718
  input_ids: Optional[torch.LongTensor] = None,
719
  attention_mask: Optional[torch.Tensor] = None,
720
  position_ids: Optional[torch.LongTensor] = None,
721
- past_key_values: Optional[Cache] = None,
722
  inputs_embeds: Optional[torch.FloatTensor] = None,
723
  labels: Optional[torch.LongTensor] = None,
724
  use_cache: Optional[bool] = None,
725
  output_attentions: Optional[bool] = None,
726
  output_hidden_states: Optional[bool] = None,
727
- ) -> SequenceClassifierOutputWithPast:
 
728
  r"""
729
  labels (`torch.LongTensor` of shape `(batch_size,)`, *optional*):
730
  Labels for computing the sequence classification/regression loss. Indices should be in `[0, ...,
731
  config.num_labels - 1]`. If `config.num_labels == 1` a regression loss is computed (Mean-Square loss), If
732
  `config.num_labels > 1` a classification loss is computed (Cross-Entropy).
733
  """
 
734
 
735
- transformer_outputs: BaseModelOutputWithPast = self.model(
736
  input_ids,
737
  attention_mask=attention_mask,
738
  position_ids=position_ids,
@@ -741,8 +1381,9 @@ class LlamaForSequenceClassification(LlamaPreTrainedModel):
741
  use_cache=use_cache,
742
  output_attentions=output_attentions,
743
  output_hidden_states=output_hidden_states,
 
744
  )
745
- hidden_states = transformer_outputs.last_hidden_state
746
  logits = self.score(hidden_states)
747
 
748
  if input_ids is not None:
@@ -753,25 +1394,26 @@ class LlamaForSequenceClassification(LlamaPreTrainedModel):
753
  if self.config.pad_token_id is None and batch_size != 1:
754
  raise ValueError("Cannot handle batch sizes > 1 if no padding token is defined.")
755
  if self.config.pad_token_id is None:
756
- last_non_pad_token = -1
757
- elif input_ids is not None:
758
- # To handle both left- and right- padding, we take the rightmost token that is not equal to pad_token_id
759
- non_pad_mask = (input_ids != self.config.pad_token_id).to(logits.device, torch.int32)
760
- token_indices = torch.arange(input_ids.shape[-1], device=logits.device, dtype=torch.int32)
761
- last_non_pad_token = (token_indices * non_pad_mask).argmax(-1)
762
  else:
763
- last_non_pad_token = -1
764
- logger.warning_once(
765
- f"{self.__class__.__name__} will not detect padding tokens in `inputs_embeds`. Results may be "
766
- "unexpected if using padding tokens in conjunction with `inputs_embeds.`"
767
- )
 
 
768
 
769
- pooled_logits = logits[torch.arange(batch_size, device=logits.device), last_non_pad_token]
770
 
771
  loss = None
772
  if labels is not None:
773
  loss = self.loss_function(logits=logits, labels=labels, pooled_logits=pooled_logits, config=self.config)
774
 
 
 
 
 
775
  return SequenceClassifierOutputWithPast(
776
  loss=loss,
777
  logits=pooled_logits,
@@ -781,7 +1423,13 @@ class LlamaForSequenceClassification(LlamaPreTrainedModel):
781
  )
782
 
783
 
784
- @auto_docstring
 
 
 
 
 
 
785
  class LlamaForQuestionAnswering(LlamaPreTrainedModel):
786
  base_model_prefix = "transformer"
787
 
@@ -800,22 +1448,34 @@ class LlamaForQuestionAnswering(LlamaPreTrainedModel):
800
  def set_input_embeddings(self, value):
801
  self.transformer.embed_tokens = value
802
 
803
- @can_return_tuple
804
- @auto_docstring
805
  def forward(
806
  self,
807
  input_ids: Optional[torch.LongTensor] = None,
808
- attention_mask: Optional[torch.Tensor] = None,
809
  position_ids: Optional[torch.LongTensor] = None,
810
- past_key_values: Optional[Cache] = None,
811
  inputs_embeds: Optional[torch.FloatTensor] = None,
812
  start_positions: Optional[torch.LongTensor] = None,
813
  end_positions: Optional[torch.LongTensor] = None,
814
  output_attentions: Optional[bool] = None,
815
  output_hidden_states: Optional[bool] = None,
 
816
  **kwargs,
817
- ) -> QuestionAnsweringModelOutput:
818
- outputs: BaseModelOutputWithPast = self.transformer(
 
 
 
 
 
 
 
 
 
 
 
 
819
  input_ids,
820
  attention_mask=attention_mask,
821
  position_ids=position_ids,
@@ -823,9 +1483,10 @@ class LlamaForQuestionAnswering(LlamaPreTrainedModel):
823
  inputs_embeds=inputs_embeds,
824
  output_attentions=output_attentions,
825
  output_hidden_states=output_hidden_states,
 
826
  )
827
 
828
- sequence_output = outputs.last_hidden_state
829
 
830
  logits = self.qa_outputs(sequence_output)
831
  start_logits, end_logits = logits.split(1, dim=-1)
@@ -836,6 +1497,10 @@ class LlamaForQuestionAnswering(LlamaPreTrainedModel):
836
  if start_positions is not None and end_positions is not None:
837
  loss = self.loss_function(start_logits, end_logits, start_positions, end_positions, **kwargs)
838
 
 
 
 
 
839
  return QuestionAnsweringModelOutput(
840
  loss=loss,
841
  start_logits=start_logits,
@@ -845,7 +1510,13 @@ class LlamaForQuestionAnswering(LlamaPreTrainedModel):
845
  )
846
 
847
 
848
- @auto_docstring
 
 
 
 
 
 
849
  class LlamaForTokenClassification(LlamaPreTrainedModel):
850
  def __init__(self, config):
851
  super().__init__(config)
@@ -869,28 +1540,34 @@ class LlamaForTokenClassification(LlamaPreTrainedModel):
869
  def set_input_embeddings(self, value):
870
  self.model.embed_tokens = value
871
 
872
- @can_return_tuple
873
- @auto_docstring
 
 
 
 
874
  def forward(
875
  self,
876
  input_ids: Optional[torch.LongTensor] = None,
877
  attention_mask: Optional[torch.Tensor] = None,
878
  position_ids: Optional[torch.LongTensor] = None,
879
- past_key_values: Optional[Cache] = None,
880
  inputs_embeds: Optional[torch.FloatTensor] = None,
881
  labels: Optional[torch.LongTensor] = None,
882
  use_cache: Optional[bool] = None,
883
  output_attentions: Optional[bool] = None,
884
  output_hidden_states: Optional[bool] = None,
885
- ) -> TokenClassifierOutput:
 
886
  r"""
887
  labels (`torch.LongTensor` of shape `(batch_size,)`, *optional*):
888
  Labels for computing the sequence classification/regression loss. Indices should be in `[0, ...,
889
  config.num_labels - 1]`. If `config.num_labels == 1` a regression loss is computed (Mean-Square loss), If
890
  `config.num_labels > 1` a classification loss is computed (Cross-Entropy).
891
  """
 
892
 
893
- outputs: BaseModelOutputWithPast = self.model(
894
  input_ids,
895
  attention_mask=attention_mask,
896
  position_ids=position_ids,
@@ -899,8 +1576,9 @@ class LlamaForTokenClassification(LlamaPreTrainedModel):
899
  use_cache=use_cache,
900
  output_attentions=output_attentions,
901
  output_hidden_states=output_hidden_states,
 
902
  )
903
- sequence_output = outputs.last_hidden_state
904
  sequence_output = self.dropout(sequence_output)
905
  logits = self.score(sequence_output)
906
 
@@ -908,19 +1586,13 @@ class LlamaForTokenClassification(LlamaPreTrainedModel):
908
  if labels is not None:
909
  loss = self.loss_function(logits, labels, self.config)
910
 
 
 
 
 
911
  return TokenClassifierOutput(
912
  loss=loss,
913
  logits=logits,
914
  hidden_states=outputs.hidden_states,
915
  attentions=outputs.attentions,
916
  )
917
-
918
-
919
- __all__ = [
920
- "LlamaForCausalLM",
921
- "LlamaModel",
922
- "LlamaPreTrainedModel",
923
- "LlamaForSequenceClassification",
924
- "LlamaForQuestionAnswering",
925
- "LlamaForTokenClassification",
926
- ]
 
17
  # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
18
  # See the License for the specific language governing permissions and
19
  # limitations under the License.
20
+ import math
21
+ from typing import List, Optional, Tuple, Union
22
 
23
  import torch
24
+ import torch.nn.functional as F
25
  import torch.utils.checkpoint
26
  from torch import nn
27
 
28
  from transformers.activations import ACT2FN
29
+ from transformers.cache_utils import Cache, DynamicCache, StaticCache
30
  from transformers.generation import GenerationMixin
31
+ from transformers.modeling_attn_mask_utils import AttentionMaskConverter
32
+ from transformers.modeling_flash_attention_utils import _flash_attention_forward
 
 
33
  from transformers.modeling_outputs import (
34
  BaseModelOutputWithPast,
35
  CausalLMOutputWithPast,
 
37
  SequenceClassifierOutputWithPast,
38
  TokenClassifierOutput,
39
  )
40
+ from transformers.modeling_rope_utils import ROPE_INIT_FUNCTIONS
41
+ from transformers.modeling_utils import PreTrainedModel
42
+ from transformers.pytorch_utils import ALL_LAYERNORM_LAYERS
43
+ from transformers.utils import (
44
+ add_code_sample_docstrings,
45
+ add_start_docstrings,
46
+ add_start_docstrings_to_model_forward,
47
+ is_flash_attn_greater_or_equal_2_10,
48
+ logging,
49
+ replace_return_docstrings,
50
+ )
51
  from .configuration_llama import LlamaConfig
52
 
53
 
54
  logger = logging.get_logger(__name__)
55
 
56
+ _CHECKPOINT_FOR_DOC = "meta-llama/Llama-2-7b-hf"
57
+ _CONFIG_FOR_DOC = "LlamaConfig"
58
+
59
 
60
  def get_quantizer(quant_type="none", bit=4, group_size=128):
61
  if quant_type == "intsym":
 
150
  x = x + self.bias
151
  return x
152
 
 
153
  class LlamaRMSNorm(nn.Module):
154
  def __init__(self, hidden_size, eps=1e-6):
155
  """
 
170
  return f"{tuple(self.weight.shape)}, eps={self.variance_epsilon}"
171
 
172
 
173
+ ALL_LAYERNORM_LAYERS.append(LlamaRMSNorm)
174
+
175
+
176
  class LlamaRotaryEmbedding(nn.Module):
177
+ def __init__(
178
+ self,
179
+ dim=None,
180
+ max_position_embeddings=2048,
181
+ base=10000,
182
+ device=None,
183
+ scaling_factor=1.0,
184
+ rope_type="default",
185
+ config: Optional[LlamaConfig] = None,
186
+ ):
187
  super().__init__()
188
+ # TODO (joao): remove the `if` below, only used for BC
189
+ self.rope_kwargs = {}
190
+ if config is None:
191
+ logger.warning_once(
192
+ "`LlamaRotaryEmbedding` can now be fully parameterized by passing the model config through the "
193
+ "`config` argument. All other arguments will be removed in v4.46"
194
+ )
195
+ self.rope_kwargs = {
196
+ "rope_type": rope_type,
197
+ "factor": scaling_factor,
198
+ "dim": dim,
199
+ "base": base,
200
+ "max_position_embeddings": max_position_embeddings,
201
+ }
202
+ self.rope_type = rope_type
203
+ self.max_seq_len_cached = max_position_embeddings
204
+ self.original_max_seq_len = max_position_embeddings
205
  else:
206
+ # BC: "rope_type" was originally "type"
207
+ if config.rope_scaling is not None:
208
+ self.rope_type = config.rope_scaling.get("rope_type", config.rope_scaling.get("type"))
209
+ else:
210
+ self.rope_type = "default"
211
+ self.max_seq_len_cached = config.max_position_embeddings
212
+ self.original_max_seq_len = config.max_position_embeddings
213
 
214
  self.config = config
215
  self.rope_init_fn = ROPE_INIT_FUNCTIONS[self.rope_type]
216
 
217
+ inv_freq, self.attention_scaling = self.rope_init_fn(self.config, device, **self.rope_kwargs)
218
  self.register_buffer("inv_freq", inv_freq, persistent=False)
219
  self.original_inv_freq = self.inv_freq
220
 
221
+ def _dynamic_frequency_update(self, position_ids, device):
222
+ """
223
+ dynamic RoPE layers should recompute `inv_freq` in the following situations:
224
+ 1 - growing beyond the cached sequence length (allow scaling)
225
+ 2 - the current sequence length is in the original scale (avoid losing precision with small sequences)
226
+ """
227
+ seq_len = torch.max(position_ids) + 1
228
+ if seq_len > self.max_seq_len_cached: # growth
229
+ inv_freq, self.attention_scaling = self.rope_init_fn(
230
+ self.config, device, seq_len=seq_len, **self.rope_kwargs
231
+ )
232
+ self.register_buffer("inv_freq", inv_freq, persistent=False) # TODO joao: may break with compilation
233
+ self.max_seq_len_cached = seq_len
234
+
235
+ if seq_len < self.original_max_seq_len and self.max_seq_len_cached > self.original_max_seq_len: # reset
236
+ self.register_buffer("inv_freq", self.original_inv_freq, persistent=False)
237
+ self.max_seq_len_cached = self.original_max_seq_len
238
+
239
  @torch.no_grad()
 
240
  def forward(self, x, position_ids):
241
+ if "dynamic" in self.rope_type:
242
+ self._dynamic_frequency_update(position_ids, device=x.device)
243
 
244
+ # Core RoPE block
245
+ inv_freq_expanded = self.inv_freq[None, :, None].float().expand(position_ids.shape[0], -1, 1)
246
+ position_ids_expanded = position_ids[:, None, :].float()
247
+ # Force float32 (see https://github.com/huggingface/transformers/pull/29285)
248
+ device_type = x.device.type
249
+ device_type = device_type if isinstance(device_type, str) and device_type != "mps" else "cpu"
250
+ with torch.autocast(device_type=device_type, enabled=False):
251
  freqs = (inv_freq_expanded.float() @ position_ids_expanded.float()).transpose(1, 2)
252
  emb = torch.cat((freqs, freqs), dim=-1)
253
+ cos = emb.cos()
254
+ sin = emb.sin()
255
+
256
+ # Advanced RoPE types (e.g. yarn) apply a post-processing scaling factor, equivalent to scaling attention
257
+ cos = cos * self.attention_scaling
258
+ sin = sin * self.attention_scaling
259
 
260
  return cos.to(dtype=x.dtype), sin.to(dtype=x.dtype)
261
 
262
 
263
+ class LlamaLinearScalingRotaryEmbedding(LlamaRotaryEmbedding):
264
+ """LlamaRotaryEmbedding extended with linear scaling. Credits to the Reddit user /u/kaiokendev"""
265
+
266
+ def __init__(self, *args, **kwargs):
267
+ logger.warning_once(
268
+ "`LlamaLinearScalingRotaryEmbedding` is deprecated an will be removed in v4.46. Please use "
269
+ "`LlamaRotaryEmbedding`, which now also does linear scaling (simply pass the model config to __init__)."
270
+ )
271
+ kwargs["rope_type"] = "linear"
272
+ super().__init__(*args, **kwargs)
273
+
274
+
275
+ class LlamaDynamicNTKScalingRotaryEmbedding(LlamaRotaryEmbedding):
276
+ """LlamaRotaryEmbedding extended with Dynamic NTK scaling. Credits to the Reddit users /u/bloc97 and /u/emozilla"""
277
+
278
+ def __init__(self, *args, **kwargs):
279
+ logger.warning_once(
280
+ "`LlamaDynamicNTKScalingRotaryEmbedding` is deprecated an will be removed in v4.46. Please use "
281
+ "`LlamaRotaryEmbedding`, which now also does dynamic ntk scaling (simply pass the model config to "
282
+ "__init__)."
283
+ )
284
+ kwargs["rope_type"] = "dynamic"
285
+ super().__init__(*args, **kwargs)
286
+
287
+
288
  def rotate_half(x):
289
  """Rotates half the hidden dims of the input."""
290
  x1 = x[..., : x.shape[-1] // 2]
 
328
  self.gate_proj = LinearQuantizer(self.hidden_size, self.intermediate_size, bias=config.mlp_bias, quant_type="ternary", bit=4, group_size=-1)
329
  self.up_proj = LinearQuantizer(self.hidden_size, self.intermediate_size, bias=config.mlp_bias, quant_type="ternary", bit=4, group_size=-1)
330
  self.down_proj = LinearQuantizer(self.intermediate_size, self.hidden_size, bias=config.mlp_bias, quant_type="ternary", bit=4, group_size=-1)
 
 
 
331
  self.act_fn = ACT2FN[config.hidden_act]
332
 
333
  def forward(self, x):
334
+ if self.config.pretraining_tp > 1:
335
+ slice = self.intermediate_size // self.config.pretraining_tp
336
+ gate_proj_slices = self.gate_proj.weight.split(slice, dim=0)
337
+ up_proj_slices = self.up_proj.weight.split(slice, dim=0)
338
+ down_proj_slices = self.down_proj.weight.split(slice, dim=1)
339
+
340
+ gate_proj = torch.cat(
341
+ [F.linear(x, gate_proj_slices[i]) for i in range(self.config.pretraining_tp)], dim=-1
342
+ )
343
+ up_proj = torch.cat([F.linear(x, up_proj_slices[i]) for i in range(self.config.pretraining_tp)], dim=-1)
344
+
345
+ intermediate_states = (self.act_fn(gate_proj) * up_proj).split(slice, dim=2)
346
+ down_proj = [
347
+ F.linear(intermediate_states[i], down_proj_slices[i]) for i in range(self.config.pretraining_tp)
348
+ ]
349
+ down_proj = sum(down_proj)
350
+ else:
351
+ down_proj = self.down_proj(self.act_fn(self.gate_proj(x)) * self.up_proj(x))
352
+
353
  return down_proj
354
 
355
 
 
365
  return hidden_states.reshape(batch, num_key_value_heads * n_rep, slen, head_dim)
366
 
367
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
368
  class LlamaAttention(nn.Module):
369
  """Multi-headed attention from 'Attention Is All You Need' paper"""
370
 
371
+ def __init__(self, config: LlamaConfig, layer_idx: Optional[int] = None):
372
  super().__init__()
373
  self.config = config
374
  self.layer_idx = layer_idx
375
+ if layer_idx is None:
376
+ logger.warning_once(
377
+ f"Instantiating {self.__class__.__name__} without passing a `layer_idx` is not recommended and will "
378
+ "lead to errors during the forward call if caching is used. Please make sure to provide a `layer_idx` "
379
+ "when creating this class."
380
+ )
381
+
382
  self.attention_dropout = config.attention_dropout
383
+ self.hidden_size = config.hidden_size
384
+ self.num_heads = config.num_attention_heads
385
+ self.head_dim = getattr(config, "head_dim", self.hidden_size // self.num_heads)
386
+ self.num_key_value_heads = config.num_key_value_heads
387
+ self.num_key_value_groups = self.num_heads // self.num_key_value_heads
388
+ self.max_position_embeddings = config.max_position_embeddings
389
+ self.rope_theta = config.rope_theta
390
  self.is_causal = True
391
 
392
+ self.q_proj = LinearQuantizer(self.hidden_size, self.num_heads * self.head_dim, bias=config.attention_bias, quant_type="ternary", bit=4, group_size=-1)
393
+ self.k_proj = LinearQuantizer(self.hidden_size, self.num_key_value_heads * self.head_dim, bias=config.attention_bias, quant_type="ternary", bit=4, group_size=-1)
394
+ self.v_proj = LinearQuantizer(self.hidden_size, self.num_key_value_heads * self.head_dim, bias=config.attention_bias, quant_type="ternary", bit=4, group_size=-1)
395
+ self.o_proj = LinearQuantizer(self.num_heads * self.head_dim, self.hidden_size, bias=config.attention_bias, quant_type="ternary", bit=4, group_size=-1)
396
+
397
+ # TODO (joao): remove in v4.46 (RoPE is computed in the model, not in the decoder layers)
398
+ self.rotary_emb = LlamaRotaryEmbedding(config=self.config)
 
 
 
 
 
 
 
 
 
399
 
400
  def forward(
401
  self,
402
  hidden_states: torch.Tensor,
403
+ attention_mask: Optional[torch.Tensor] = None,
404
+ position_ids: Optional[torch.LongTensor] = None,
405
  past_key_value: Optional[Cache] = None,
406
+ output_attentions: bool = False,
407
+ use_cache: bool = False,
408
  cache_position: Optional[torch.LongTensor] = None,
409
+ position_embeddings: Optional[Tuple[torch.Tensor, torch.Tensor]] = None, # will become mandatory in v4.46
410
+ **kwargs,
411
+ ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
412
+ bsz, q_len, _ = hidden_states.size()
413
 
414
+ if self.config.pretraining_tp > 1:
415
+ key_value_slicing = (self.num_key_value_heads * self.head_dim) // self.config.pretraining_tp
416
+ query_slices = self.q_proj.weight.split(
417
+ (self.num_heads * self.head_dim) // self.config.pretraining_tp, dim=0
418
+ )
419
+ key_slices = self.k_proj.weight.split(key_value_slicing, dim=0)
420
+ value_slices = self.v_proj.weight.split(key_value_slicing, dim=0)
421
 
422
+ query_states = [F.linear(hidden_states, query_slices[i]) for i in range(self.config.pretraining_tp)]
423
+ query_states = torch.cat(query_states, dim=-1)
424
+
425
+ key_states = [F.linear(hidden_states, key_slices[i]) for i in range(self.config.pretraining_tp)]
426
+ key_states = torch.cat(key_states, dim=-1)
427
+
428
+ value_states = [F.linear(hidden_states, value_slices[i]) for i in range(self.config.pretraining_tp)]
429
+ value_states = torch.cat(value_states, dim=-1)
430
+
431
+ else:
432
+ query_states = self.q_proj(hidden_states)
433
+ key_states = self.k_proj(hidden_states)
434
+ value_states = self.v_proj(hidden_states)
435
+
436
+ query_states = query_states.view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2)
437
+ key_states = key_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)
438
+ value_states = value_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)
439
+
440
+ if position_embeddings is None:
441
+ logger.warning_once(
442
+ "The attention layers in this model are transitioning from computing the RoPE embeddings internally "
443
+ "through `position_ids` (2D tensor with the indexes of the tokens), to using externally computed "
444
+ "`position_embeddings` (Tuple of tensors, containing cos and sin). In v4.46 `position_ids` will be "
445
+ "removed and `position_embeddings` will be mandatory."
446
+ )
447
+ cos, sin = self.rotary_emb(value_states, position_ids)
448
+ else:
449
+ cos, sin = position_embeddings
450
  query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin)
451
 
452
  if past_key_value is not None:
 
454
  cache_kwargs = {"sin": sin, "cos": cos, "cache_position": cache_position}
455
  key_states, value_states = past_key_value.update(key_states, value_states, self.layer_idx, cache_kwargs)
456
 
457
+ key_states = repeat_kv(key_states, self.num_key_value_groups)
458
+ value_states = repeat_kv(value_states, self.num_key_value_groups)
459
+ attn_weights = torch.matmul(query_states, key_states.transpose(2, 3)) / math.sqrt(self.head_dim)
460
+
461
+ if attention_mask is not None: # no matter the length, we just slice it
462
+ causal_mask = attention_mask[:, :, :, : key_states.shape[-2]]
463
+ attn_weights = attn_weights + causal_mask
464
+
465
+ # upcast attention to fp32
466
+ attn_weights = nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to(query_states.dtype)
467
+ attn_weights = nn.functional.dropout(attn_weights, p=self.attention_dropout, training=self.training)
468
+ attn_output = torch.matmul(attn_weights, value_states)
469
+
470
+ if attn_output.size() != (bsz, self.num_heads, q_len, self.head_dim):
471
+ raise ValueError(
472
+ f"`attn_output` should be of size {(bsz, self.num_heads, q_len, self.head_dim)}, but is"
473
+ f" {attn_output.size()}"
474
+ )
475
 
476
+ attn_output = attn_output.transpose(1, 2).contiguous()
477
+
478
+ attn_output = attn_output.reshape(bsz, q_len, -1)
479
+
480
+ if self.config.pretraining_tp > 1:
481
+ attn_output = attn_output.split(self.hidden_size // self.config.pretraining_tp, dim=2)
482
+ o_proj_slices = self.o_proj.weight.split(self.hidden_size // self.config.pretraining_tp, dim=1)
483
+ attn_output = sum([F.linear(attn_output[i], o_proj_slices[i]) for i in range(self.config.pretraining_tp)])
484
+ else:
485
+ attn_output = self.o_proj(attn_output)
486
+
487
+ if not output_attentions:
488
+ attn_weights = None
489
+
490
+ return attn_output, attn_weights, past_key_value
491
+
492
+
493
+ class LlamaFlashAttention2(LlamaAttention):
494
+ """
495
+ Llama flash attention module. This module inherits from `LlamaAttention` as the weights of the module stays
496
+ untouched. The only required change would be on the forward pass where it needs to correctly call the public API of
497
+ flash attention and deal with padding tokens in case the input contains any of them.
498
+ """
499
+
500
+ def __init__(self, *args, **kwargs):
501
+ super().__init__(*args, **kwargs)
502
+
503
+ # TODO: Should be removed once Flash Attention for RoCm is bumped to 2.1.
504
+ # flash_attn<2.1 generates top-left aligned causal mask, while what is needed here is bottom-right alignement, that was made default for flash_attn>=2.1. This attribute is used to handle this difference. Reference: https://github.com/Dao-AILab/flash-attention/releases/tag/v2.1.0.
505
+ # Beware that with flash_attn<2.1, using q_seqlen != k_seqlen (except for the case q_seqlen == 1) produces a wrong mask (top-left).
506
+ self._flash_attn_uses_top_left_mask = not is_flash_attn_greater_or_equal_2_10()
507
+
508
+ def forward(
509
+ self,
510
+ hidden_states: torch.Tensor,
511
+ attention_mask: Optional[torch.LongTensor] = None,
512
+ position_ids: Optional[torch.LongTensor] = None,
513
+ past_key_value: Optional[Cache] = None,
514
+ output_attentions: bool = False,
515
+ use_cache: bool = False,
516
+ cache_position: Optional[torch.LongTensor] = None,
517
+ position_embeddings: Optional[Tuple[torch.Tensor, torch.Tensor]] = None, # will become mandatory in v4.46
518
+ ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
519
+ if isinstance(past_key_value, StaticCache):
520
+ raise ValueError(
521
+ "`static` cache implementation is not compatible with `attn_implementation==flash_attention_2` "
522
+ "make sure to use `sdpa` in the mean time, and open an issue at https://github.com/huggingface/transformers"
523
+ )
524
+
525
+ output_attentions = False
526
+
527
+ bsz, q_len, _ = hidden_states.size()
528
+
529
+ query_states = self.q_proj(hidden_states)
530
+ key_states = self.k_proj(hidden_states)
531
+ value_states = self.v_proj(hidden_states)
532
+
533
+ # Flash attention requires the input to have the shape
534
+ # batch_size x seq_length x head_dim x hidden_dim
535
+ # therefore we just need to keep the original shape
536
+ query_states = query_states.view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2)
537
+ key_states = key_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)
538
+ value_states = value_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)
539
+
540
+ if position_embeddings is None:
541
+ logger.warning_once(
542
+ "The attention layers in this model are transitioning from computing the RoPE embeddings internally "
543
+ "through `position_ids` (2D tensor with the indexes of the tokens), to using externally computed "
544
+ "`position_embeddings` (Tuple of tensors, containing cos and sin). In v4.46 `position_ids` will be "
545
+ "removed and `position_embeddings` will be mandatory."
546
+ )
547
+ cos, sin = self.rotary_emb(value_states, position_ids)
548
+ else:
549
+ cos, sin = position_embeddings
550
+ query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin)
551
+
552
+ if past_key_value is not None:
553
+ # sin and cos are specific to RoPE models; cache_position needed for the static cache
554
+ cache_kwargs = {"sin": sin, "cos": cos, "cache_position": cache_position}
555
+ key_states, value_states = past_key_value.update(key_states, value_states, self.layer_idx, cache_kwargs)
556
+
557
+ # TODO: These transpose are quite inefficient but Flash Attention requires the layout [batch_size, sequence_length, num_heads, head_dim]. We would need to refactor the KV cache
558
+ # to be able to avoid many of these transpose/reshape/view.
559
+ query_states = query_states.transpose(1, 2)
560
+ key_states = key_states.transpose(1, 2)
561
+ value_states = value_states.transpose(1, 2)
562
+
563
+ dropout_rate = self.attention_dropout if self.training else 0.0
564
+
565
+ # In PEFT, usually we cast the layer norms in float32 for training stability reasons
566
+ # therefore the input hidden states gets silently casted in float32. Hence, we need
567
+ # cast them back in the correct dtype just to be sure everything works as expected.
568
+ # This might slowdown training & inference so it is recommended to not cast the LayerNorms
569
+ # in fp32. (LlamaRMSNorm handles it correctly)
570
+
571
+ input_dtype = query_states.dtype
572
+ if input_dtype == torch.float32:
573
+ if torch.is_autocast_enabled():
574
+ target_dtype = torch.get_autocast_gpu_dtype()
575
+ # Handle the case where the model is quantized
576
+ elif hasattr(self.config, "_pre_quantization_dtype"):
577
+ target_dtype = self.config._pre_quantization_dtype
578
+ else:
579
+ target_dtype = self.q_proj.weight.dtype
580
+
581
+ logger.warning_once(
582
+ f"The input hidden states seems to be silently casted in float32, this might be related to"
583
+ f" the fact you have upcasted embedding or layer norm layers in float32. We will cast back the input in"
584
+ f" {target_dtype}."
585
+ )
586
+
587
+ query_states = query_states.to(target_dtype)
588
+ key_states = key_states.to(target_dtype)
589
+ value_states = value_states.to(target_dtype)
590
+
591
+ attn_output = _flash_attention_forward(
592
  query_states,
593
  key_states,
594
  value_states,
595
  attention_mask,
596
+ q_len,
597
+ position_ids=position_ids,
598
+ dropout=dropout_rate,
599
+ sliding_window=getattr(self, "sliding_window", None),
600
+ use_top_left_mask=self._flash_attn_uses_top_left_mask,
601
+ is_causal=self.is_causal,
602
  )
603
 
604
+ attn_output = attn_output.reshape(bsz, q_len, -1).contiguous()
605
  attn_output = self.o_proj(attn_output)
 
606
 
607
+ if not output_attentions:
608
+ attn_weights = None
609
 
610
+ return attn_output, attn_weights, past_key_value
611
+
612
+
613
+ class LlamaSdpaAttention(LlamaAttention):
614
+ """
615
+ Llama attention module using torch.nn.functional.scaled_dot_product_attention. This module inherits from
616
+ `LlamaAttention` as the weights of the module stays untouched. The only changes are on the forward pass to adapt to
617
+ SDPA API.
618
+ """
619
+
620
+ # Adapted from LlamaAttention.forward
621
+ def forward(
622
+ self,
623
+ hidden_states: torch.Tensor,
624
+ attention_mask: Optional[torch.Tensor] = None,
625
+ position_ids: Optional[torch.LongTensor] = None,
626
+ past_key_value: Optional[Cache] = None,
627
+ output_attentions: bool = False,
628
+ use_cache: bool = False,
629
+ cache_position: Optional[torch.LongTensor] = None,
630
+ position_embeddings: Optional[Tuple[torch.Tensor, torch.Tensor]] = None, # will become mandatory in v4.46
631
+ **kwargs,
632
+ ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
633
+ if output_attentions:
634
+ # TODO: Improve this warning with e.g. `model.config.attn_implementation = "manual"` once this is implemented.
635
+ logger.warning_once(
636
+ "LlamaModel is using LlamaSdpaAttention, but `torch.nn.functional.scaled_dot_product_attention` does not support `output_attentions=True`. Falling back to the manual attention implementation, "
637
+ 'but specifying the manual implementation will be required from Transformers version v5.0.0 onwards. This warning can be removed using the argument `attn_implementation="eager"` when loading the model.'
638
+ )
639
+ return super().forward(
640
+ hidden_states=hidden_states,
641
+ attention_mask=attention_mask,
642
+ position_ids=position_ids,
643
+ past_key_value=past_key_value,
644
+ output_attentions=output_attentions,
645
+ use_cache=use_cache,
646
+ cache_position=cache_position,
647
+ position_embeddings=position_embeddings,
648
+ )
649
+
650
+ bsz, q_len, _ = hidden_states.size()
651
+
652
+ query_states = self.q_proj(hidden_states)
653
+ key_states = self.k_proj(hidden_states)
654
+ value_states = self.v_proj(hidden_states)
655
+
656
+ query_states = query_states.view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2)
657
+ key_states = key_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)
658
+ value_states = value_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)
659
+
660
+ if position_embeddings is None:
661
+ logger.warning_once(
662
+ "The attention layers in this model are transitioning from computing the RoPE embeddings internally "
663
+ "through `position_ids` (2D tensor with the indexes of the tokens), to using externally computed "
664
+ "`position_embeddings` (Tuple of tensors, containing cos and sin). In v4.46 `position_ids` will be "
665
+ "removed and `position_embeddings` will be mandatory."
666
+ )
667
+ cos, sin = self.rotary_emb(value_states, position_ids)
668
+ else:
669
+ cos, sin = position_embeddings
670
+ query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin)
671
+
672
+ if past_key_value is not None:
673
+ # sin and cos are specific to RoPE models; cache_position needed for the static cache
674
+ cache_kwargs = {"sin": sin, "cos": cos, "cache_position": cache_position}
675
+ key_states, value_states = past_key_value.update(key_states, value_states, self.layer_idx, cache_kwargs)
676
+
677
+ key_states = repeat_kv(key_states, self.num_key_value_groups)
678
+ value_states = repeat_kv(value_states, self.num_key_value_groups)
679
+
680
+ causal_mask = attention_mask
681
+ if attention_mask is not None:
682
+ causal_mask = causal_mask[:, :, :, : key_states.shape[-2]]
683
+
684
+ # SDPA with memory-efficient backend is currently (torch==2.1.2) bugged with non-contiguous inputs with custom attn_mask,
685
+ # Reference: https://github.com/pytorch/pytorch/issues/112577.
686
+ if query_states.device.type == "cuda" and causal_mask is not None:
687
+ query_states = query_states.contiguous()
688
+ key_states = key_states.contiguous()
689
+ value_states = value_states.contiguous()
690
+
691
+ # We dispatch to SDPA's Flash Attention or Efficient kernels via this `is_causal` if statement instead of an inline conditional assignment
692
+ # in SDPA to support both torch.compile's dynamic shapes and full graph options. An inline conditional prevents dynamic shapes from compiling.
693
+ is_causal = True if causal_mask is None and q_len > 1 else False
694
+
695
+ attn_output = torch.nn.functional.scaled_dot_product_attention(
696
+ query_states,
697
+ key_states,
698
+ value_states,
699
+ attn_mask=causal_mask,
700
+ dropout_p=self.attention_dropout if self.training else 0.0,
701
+ is_causal=is_causal,
702
+ )
703
+
704
+ attn_output = attn_output.transpose(1, 2).contiguous()
705
+ attn_output = attn_output.view(bsz, q_len, -1)
706
+
707
+ attn_output = self.o_proj(attn_output)
708
+
709
+ return attn_output, None, past_key_value
710
+
711
+
712
+ LLAMA_ATTENTION_CLASSES = {
713
+ "eager": LlamaAttention,
714
+ "flash_attention_2": LlamaFlashAttention2,
715
+ "sdpa": LlamaSdpaAttention,
716
+ }
717
+
718
+
719
+ class LlamaDecoderLayer(nn.Module):
720
  def __init__(self, config: LlamaConfig, layer_idx: int):
721
  super().__init__()
722
  self.hidden_size = config.hidden_size
723
 
724
+ self.self_attn = LLAMA_ATTENTION_CLASSES[config._attn_implementation](config=config, layer_idx=layer_idx)
725
 
726
  self.mlp = LlamaMLP(config)
727
  self.input_layernorm = LlamaRMSNorm(config.hidden_size, eps=config.rms_norm_eps)
 
736
  output_attentions: Optional[bool] = False,
737
  use_cache: Optional[bool] = False,
738
  cache_position: Optional[torch.LongTensor] = None,
739
+ position_embeddings: Optional[Tuple[torch.Tensor, torch.Tensor]] = None, # will become mandatory in v4.46
740
+ **kwargs,
741
+ ) -> Tuple[torch.FloatTensor, Optional[Tuple[torch.FloatTensor, torch.FloatTensor]]]:
742
+ """
743
+ Args:
744
+ hidden_states (`torch.FloatTensor`): input to the layer of shape `(batch, seq_len, embed_dim)`
745
+ attention_mask (`torch.FloatTensor`, *optional*):
746
+ attention mask of size `(batch_size, sequence_length)` if flash attention is used or `(batch_size, 1,
747
+ query_sequence_length, key_sequence_length)` if default attention is used.
748
+ output_attentions (`bool`, *optional*):
749
+ Whether or not to return the attentions tensors of all attention layers. See `attentions` under
750
+ returned tensors for more detail.
751
+ use_cache (`bool`, *optional*):
752
+ If set to `True`, `past_key_values` key value states are returned and can be used to speed up decoding
753
+ (see `past_key_values`).
754
+ past_key_value (`Tuple(torch.FloatTensor)`, *optional*): cached past key and value projection states
755
+ cache_position (`torch.LongTensor` of shape `(sequence_length)`, *optional*):
756
+ Indices depicting the position of the input sequence tokens in the sequence
757
+ position_embeddings (`Tuple[torch.FloatTensor, torch.FloatTensor]`, *optional*):
758
+ Tuple containing the cosine and sine positional embeddings of shape `(batch_size, seq_len, head_dim)`,
759
+ with `head_dim` being the embedding dimension of each attention head.
760
+ kwargs (`dict`, *optional*):
761
+ Arbitrary kwargs to be ignored, used for FSDP and other methods that injects code
762
+ into the model
763
+ """
764
  residual = hidden_states
765
+
766
  hidden_states = self.input_layernorm(hidden_states)
767
 
768
  # Self Attention
769
+ hidden_states, self_attn_weights, present_key_value = self.self_attn(
770
  hidden_states=hidden_states,
771
  attention_mask=attention_mask,
772
  position_ids=position_ids,
 
786
  hidden_states = residual + hidden_states
787
 
788
  outputs = (hidden_states,)
789
+
790
  if output_attentions:
791
  outputs += (self_attn_weights,)
792
 
793
+ if use_cache:
794
+ outputs += (present_key_value,)
795
+
796
  return outputs
797
 
798
 
799
+ LLAMA_START_DOCSTRING = r"""
800
+ This model inherits from [`PreTrainedModel`]. Check the superclass documentation for the generic methods the
801
+ library implements for all its model (such as downloading or saving, resizing the input embeddings, pruning heads
802
+ etc.)
803
+
804
+ This model is also a PyTorch [torch.nn.Module](https://pytorch.org/docs/stable/nn.html#torch.nn.Module) subclass.
805
+ Use it as a regular PyTorch Module and refer to the PyTorch documentation for all matter related to general usage
806
+ and behavior.
807
+
808
+ Parameters:
809
+ config ([`LlamaConfig`]):
810
+ Model configuration class with all the parameters of the model. Initializing with a config file does not
811
+ load the weights associated with the model, only the configuration. Check out the
812
+ [`~PreTrainedModel.from_pretrained`] method to load the model weights.
813
+ """
814
+
815
+
816
+ @add_start_docstrings(
817
+ "The bare LLaMA Model outputting raw hidden-states without any specific head on top.",
818
+ LLAMA_START_DOCSTRING,
819
+ )
820
  class LlamaPreTrainedModel(PreTrainedModel):
821
  config_class = LlamaConfig
822
  base_model_prefix = "model"
823
  supports_gradient_checkpointing = True
824
  _no_split_modules = ["LlamaDecoderLayer"]
825
  _skip_keys_device_placement = ["past_key_values"]
 
826
  _supports_flash_attn_2 = True
827
  _supports_sdpa = True
 
828
  _supports_cache_class = True
829
  _supports_quantized_cache = True
830
  _supports_static_cache = True
 
831
 
832
  def _init_weights(self, module):
833
  std = self.config.initializer_range
 
839
  module.weight.data.normal_(mean=0.0, std=std)
840
  if module.padding_idx is not None:
841
  module.weight.data[module.padding_idx].zero_()
 
 
842
 
843
 
844
+ LLAMA_INPUTS_DOCSTRING = r"""
845
+ Args:
846
+ input_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`):
847
+ Indices of input sequence tokens in the vocabulary. Padding will be ignored by default should you provide
848
+ it.
849
+
850
+ Indices can be obtained using [`AutoTokenizer`]. See [`PreTrainedTokenizer.encode`] and
851
+ [`PreTrainedTokenizer.__call__`] for details.
852
+
853
+ [What are input IDs?](../glossary#input-ids)
854
+ attention_mask (`torch.Tensor` of shape `(batch_size, sequence_length)`, *optional*):
855
+ Mask to avoid performing attention on padding token indices. Mask values selected in `[0, 1]`:
856
+
857
+ - 1 for tokens that are **not masked**,
858
+ - 0 for tokens that are **masked**.
859
+
860
+ [What are attention masks?](../glossary#attention-mask)
861
+
862
+ Indices can be obtained using [`AutoTokenizer`]. See [`PreTrainedTokenizer.encode`] and
863
+ [`PreTrainedTokenizer.__call__`] for details.
864
+
865
+ If `past_key_values` is used, optionally only the last `input_ids` have to be input (see
866
+ `past_key_values`).
867
+
868
+ If you want to change padding behavior, you should read [`modeling_opt._prepare_decoder_attention_mask`]
869
+ and modify to your needs. See diagram 1 in [the paper](https://arxiv.org/abs/1910.13461) for more
870
+ information on the default strategy.
871
+
872
+ - 1 indicates the head is **not masked**,
873
+ - 0 indicates the head is **masked**.
874
+ position_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
875
+ Indices of positions of each input sequence tokens in the position embeddings. Selected in the range `[0,
876
+ config.n_positions - 1]`.
877
+
878
+ [What are position IDs?](../glossary#position-ids)
879
+ past_key_values (`Cache` or `tuple(tuple(torch.FloatTensor))`, *optional*):
880
+ Pre-computed hidden-states (key and values in the self-attention blocks and in the cross-attention
881
+ blocks) that can be used to speed up sequential decoding. This typically consists in the `past_key_values`
882
+ returned by the model at a previous stage of decoding, when `use_cache=True` or `config.use_cache=True`.
883
+
884
+ Two formats are allowed:
885
+ - a [`~cache_utils.Cache`] instance, see our
886
+ [kv cache guide](https://huggingface.co/docs/transformers/en/kv_cache);
887
+ - Tuple of `tuple(torch.FloatTensor)` of length `config.n_layers`, with each tuple having 2 tensors of
888
+ shape `(batch_size, num_heads, sequence_length, embed_size_per_head)`). This is also known as the legacy
889
+ cache format.
890
+
891
+ The model will output the same cache format that is fed as input. If no `past_key_values` are passed, the
892
+ legacy cache format will be returned.
893
+
894
+ If `past_key_values` are used, the user can optionally input only the last `input_ids` (those that don't
895
+ have their past key value states given to this model) of shape `(batch_size, 1)` instead of all `input_ids`
896
+ of shape `(batch_size, sequence_length)`.
897
+ inputs_embeds (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`, *optional*):
898
+ Optionally, instead of passing `input_ids` you can choose to directly pass an embedded representation. This
899
+ is useful if you want more control over how to convert `input_ids` indices into associated vectors than the
900
+ model's internal embedding lookup matrix.
901
+ use_cache (`bool`, *optional*):
902
+ If set to `True`, `past_key_values` key value states are returned and can be used to speed up decoding (see
903
+ `past_key_values`).
904
+ output_attentions (`bool`, *optional*):
905
+ Whether or not to return the attentions tensors of all attention layers. See `attentions` under returned
906
+ tensors for more detail.
907
+ output_hidden_states (`bool`, *optional*):
908
+ Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors for
909
+ more detail.
910
+ return_dict (`bool`, *optional*):
911
+ Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple.
912
+ cache_position (`torch.LongTensor` of shape `(sequence_length)`, *optional*):
913
+ Indices depicting the position of the input sequence tokens in the sequence. Contrarily to `position_ids`,
914
+ this tensor is not affected by padding. It is used to update the cache in the correct position and to infer
915
+ the complete sequence length.
916
+ """
917
+
918
+
919
+ @add_start_docstrings(
920
+ "The bare LLaMA Model outputting raw hidden-states without any specific head on top.",
921
+ LLAMA_START_DOCSTRING,
922
+ )
923
  class LlamaModel(LlamaPreTrainedModel):
924
+ """
925
+ Transformer decoder consisting of *config.num_hidden_layers* layers. Each layer is a [`LlamaDecoderLayer`]
926
+
927
+ Args:
928
+ config: LlamaConfig
929
+ """
930
+
931
  def __init__(self, config: LlamaConfig):
932
  super().__init__(config)
933
  self.padding_idx = config.pad_token_id
 
950
  def set_input_embeddings(self, value):
951
  self.embed_tokens = value
952
 
953
+ @add_start_docstrings_to_model_forward(LLAMA_INPUTS_DOCSTRING)
 
954
  def forward(
955
  self,
956
+ input_ids: torch.LongTensor = None,
957
  attention_mask: Optional[torch.Tensor] = None,
958
  position_ids: Optional[torch.LongTensor] = None,
959
+ past_key_values: Optional[Union[Cache, List[torch.FloatTensor]]] = None,
960
  inputs_embeds: Optional[torch.FloatTensor] = None,
961
  use_cache: Optional[bool] = None,
962
  output_attentions: Optional[bool] = None,
963
  output_hidden_states: Optional[bool] = None,
964
+ return_dict: Optional[bool] = None,
965
  cache_position: Optional[torch.LongTensor] = None,
966
+ ) -> Union[Tuple, BaseModelOutputWithPast]:
 
967
  output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
968
  output_hidden_states = (
969
  output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
970
  )
971
  use_cache = use_cache if use_cache is not None else self.config.use_cache
972
+ return_dict = return_dict if return_dict is not None else self.config.use_return_dict
973
 
974
  if (input_ids is None) ^ (inputs_embeds is not None):
975
  raise ValueError("You must specify exactly one of input_ids or inputs_embeds")
 
980
  )
981
  use_cache = False
982
 
 
 
 
 
983
  if inputs_embeds is None:
984
  inputs_embeds = self.embed_tokens(input_ids)
985
 
986
+ # kept for BC (non `Cache` `past_key_values` inputs)
987
+ return_legacy_cache = False
988
+ if use_cache and not isinstance(past_key_values, Cache):
989
+ return_legacy_cache = True
990
+ if past_key_values is None:
991
+ past_key_values = DynamicCache()
992
+ else:
993
+ past_key_values = DynamicCache.from_legacy_cache(past_key_values)
994
+ logger.warning_once(
995
+ "We detected that you are passing `past_key_values` as a tuple of tuples. This is deprecated and "
996
+ "will be removed in v4.47. Please convert your cache or use an appropriate `Cache` class "
997
+ "(https://huggingface.co/docs/transformers/kv_cache#legacy-cache-format)"
998
+ )
999
 
1000
  if cache_position is None:
1001
  past_seen_tokens = past_key_values.get_seq_length() if past_key_values is not None else 0
1002
  cache_position = torch.arange(
1003
  past_seen_tokens, past_seen_tokens + inputs_embeds.shape[1], device=inputs_embeds.device
1004
  )
 
1005
  if position_ids is None:
1006
  position_ids = cache_position.unsqueeze(0)
1007
 
1008
+ causal_mask = self._update_causal_mask(
1009
+ attention_mask, inputs_embeds, cache_position, past_key_values, output_attentions
 
 
 
 
 
1010
  )
 
1011
  hidden_states = inputs_embeds
1012
 
1013
  # create position embeddings to be shared across the decoder layers
 
1016
  # decoder layers
1017
  all_hidden_states = () if output_hidden_states else None
1018
  all_self_attns = () if output_attentions else None
1019
+ next_decoder_cache = None
1020
 
1021
+ for decoder_layer in self.layers:
1022
  if output_hidden_states:
1023
  all_hidden_states += (hidden_states,)
1024
 
1025
+ if self.gradient_checkpointing and self.training:
1026
+ layer_outputs = self._gradient_checkpointing_func(
1027
+ decoder_layer.__call__,
1028
+ hidden_states,
1029
+ causal_mask,
1030
+ position_ids,
1031
+ past_key_values,
1032
+ output_attentions,
1033
+ use_cache,
1034
+ cache_position,
1035
+ position_embeddings,
1036
+ )
1037
+ else:
1038
+ layer_outputs = decoder_layer(
1039
+ hidden_states,
1040
+ attention_mask=causal_mask,
1041
+ position_ids=position_ids,
1042
+ past_key_value=past_key_values,
1043
+ output_attentions=output_attentions,
1044
+ use_cache=use_cache,
1045
+ cache_position=cache_position,
1046
+ position_embeddings=position_embeddings,
1047
+ )
1048
 
1049
  hidden_states = layer_outputs[0]
1050
 
1051
+ if use_cache:
1052
+ next_decoder_cache = layer_outputs[2 if output_attentions else 1]
1053
+
1054
  if output_attentions:
1055
  all_self_attns += (layer_outputs[1],)
1056
 
 
1060
  if output_hidden_states:
1061
  all_hidden_states += (hidden_states,)
1062
 
1063
+ next_cache = next_decoder_cache if use_cache else None
1064
+ if return_legacy_cache:
1065
+ next_cache = next_cache.to_legacy_cache()
1066
+
1067
+ if not return_dict:
1068
+ return tuple(v for v in [hidden_states, next_cache, all_hidden_states, all_self_attns] if v is not None)
1069
  return BaseModelOutputWithPast(
1070
  last_hidden_state=hidden_states,
1071
+ past_key_values=next_cache,
1072
  hidden_states=all_hidden_states,
1073
  attentions=all_self_attns,
1074
  )
1075
 
1076
+ def _update_causal_mask(
1077
+ self,
1078
+ attention_mask: torch.Tensor,
1079
+ input_tensor: torch.Tensor,
1080
+ cache_position: torch.Tensor,
1081
+ past_key_values: Cache,
1082
+ output_attentions: bool,
1083
+ ):
1084
+ if self.config._attn_implementation == "flash_attention_2":
1085
+ if attention_mask is not None and 0.0 in attention_mask:
1086
+ return attention_mask
1087
+ return None
1088
+
1089
+ # For SDPA, when possible, we will rely on its `is_causal` argument instead of its `attn_mask` argument, in
1090
+ # order to dispatch on Flash Attention 2. This feature is not compatible with static cache, as SDPA will fail
1091
+ # to infer the attention mask.
1092
+ past_seen_tokens = past_key_values.get_seq_length() if past_key_values is not None else 0
1093
+ using_static_cache = isinstance(past_key_values, StaticCache)
1094
+
1095
+ # When output attentions is True, sdpa implementation's forward method calls the eager implementation's forward
1096
+ if self.config._attn_implementation == "sdpa" and not using_static_cache and not output_attentions:
1097
+ if AttentionMaskConverter._ignore_causal_mask_sdpa(
1098
+ attention_mask,
1099
+ inputs_embeds=input_tensor,
1100
+ past_key_values_length=past_seen_tokens,
1101
+ is_training=self.training,
1102
+ ):
1103
+ return None
1104
+
1105
+ dtype, device = input_tensor.dtype, input_tensor.device
1106
+ sequence_length = input_tensor.shape[1]
1107
+ if using_static_cache:
1108
+ target_length = past_key_values.get_max_cache_shape()
1109
+ else:
1110
+ target_length = (
1111
+ attention_mask.shape[-1]
1112
+ if isinstance(attention_mask, torch.Tensor)
1113
+ else past_seen_tokens + sequence_length + 1
1114
+ )
1115
+
1116
+ # In case the provided `attention` mask is 2D, we generate a causal mask here (4D).
1117
+ causal_mask = self._prepare_4d_causal_attention_mask_with_cache_position(
1118
+ attention_mask,
1119
+ sequence_length=sequence_length,
1120
+ target_length=target_length,
1121
+ dtype=dtype,
1122
+ device=device,
1123
+ cache_position=cache_position,
1124
+ batch_size=input_tensor.shape[0],
1125
+ )
1126
 
1127
+ if (
1128
+ self.config._attn_implementation == "sdpa"
1129
+ and attention_mask is not None
1130
+ and attention_mask.device.type == "cuda"
1131
+ and not output_attentions
1132
+ ):
1133
+ # Attend to all tokens in fully masked rows in the causal_mask, for example the relevant first rows when
1134
+ # using left padding. This is required by F.scaled_dot_product_attention memory-efficient attention path.
1135
+ # Details: https://github.com/pytorch/pytorch/issues/110213
1136
+ min_dtype = torch.finfo(dtype).min
1137
+ causal_mask = AttentionMaskConverter._unmask_unattended(causal_mask, min_dtype)
1138
+
1139
+ return causal_mask
1140
+
1141
+ @staticmethod
1142
+ def _prepare_4d_causal_attention_mask_with_cache_position(
1143
+ attention_mask: torch.Tensor,
1144
+ sequence_length: int,
1145
+ target_length: int,
1146
+ dtype: torch.dtype,
1147
+ device: torch.device,
1148
+ cache_position: torch.Tensor,
1149
+ batch_size: int,
1150
+ **kwargs,
1151
+ ):
1152
+ """
1153
+ Creates a causal 4D mask of shape `(batch_size, 1, query_length, key_value_length)` from a 2D mask of shape
1154
+ `(batch_size, key_value_length)`, or if the input `attention_mask` is already 4D, do nothing.
1155
+
1156
+ Args:
1157
+ attention_mask (`torch.Tensor`):
1158
+ A 2D attention mask of shape `(batch_size, key_value_length)` or a 4D attention mask of shape
1159
+ `(batch_size, 1, query_length, key_value_length)`.
1160
+ sequence_length (`int`):
1161
+ The sequence length being processed.
1162
+ target_length (`int`):
1163
+ The target length: when generating with static cache, the mask should be as long as the static cache,
1164
+ to account for the 0 padding, the part of the cache that is not filled yet.
1165
+ dtype (`torch.dtype`):
1166
+ The dtype to use for the 4D attention mask.
1167
+ device (`torch.device`):
1168
+ The device to plcae the 4D attention mask on.
1169
+ cache_position (`torch.Tensor`):
1170
+ Indices depicting the position of the input sequence tokens in the sequence.
1171
+ batch_size (`torch.Tensor`):
1172
+ Batch size.
1173
+ """
1174
+ if attention_mask is not None and attention_mask.dim() == 4:
1175
+ # In this case we assume that the mask comes already in inverted form and requires no inversion or slicing.
1176
+ causal_mask = attention_mask
1177
+ else:
1178
+ min_dtype = torch.finfo(dtype).min
1179
+ causal_mask = torch.full(
1180
+ (sequence_length, target_length), fill_value=min_dtype, dtype=dtype, device=device
1181
+ )
1182
+ if sequence_length != 1:
1183
+ causal_mask = torch.triu(causal_mask, diagonal=1)
1184
+ causal_mask *= torch.arange(target_length, device=device) > cache_position.reshape(-1, 1)
1185
+ causal_mask = causal_mask[None, None, :, :].expand(batch_size, 1, -1, -1)
1186
+ if attention_mask is not None:
1187
+ causal_mask = causal_mask.clone() # copy to contiguous memory for in-place edit
1188
+ mask_length = attention_mask.shape[-1]
1189
+ padding_mask = causal_mask[:, :, :, :mask_length] + attention_mask[:, None, None, :]
1190
+ padding_mask = padding_mask == 0
1191
+ causal_mask[:, :, :, :mask_length] = causal_mask[:, :, :, :mask_length].masked_fill(
1192
+ padding_mask, min_dtype
1193
+ )
1194
+
1195
+ return causal_mask
1196
 
1197
 
 
1198
  class LlamaForCausalLM(LlamaPreTrainedModel, GenerationMixin):
1199
  _tied_weights_keys = ["lm_head.weight"]
 
 
1200
 
1201
  def __init__(self, config):
1202
  super().__init__(config)
 
1225
  def get_decoder(self):
1226
  return self.model
1227
 
1228
+ @add_start_docstrings_to_model_forward(LLAMA_INPUTS_DOCSTRING)
1229
+ @replace_return_docstrings(output_type=CausalLMOutputWithPast, config_class=_CONFIG_FOR_DOC)
1230
  def forward(
1231
  self,
1232
+ input_ids: torch.LongTensor = None,
1233
  attention_mask: Optional[torch.Tensor] = None,
1234
  position_ids: Optional[torch.LongTensor] = None,
1235
+ past_key_values: Optional[Union[Cache, List[torch.FloatTensor]]] = None,
1236
  inputs_embeds: Optional[torch.FloatTensor] = None,
1237
  labels: Optional[torch.LongTensor] = None,
1238
  use_cache: Optional[bool] = None,
1239
  output_attentions: Optional[bool] = None,
1240
  output_hidden_states: Optional[bool] = None,
1241
+ return_dict: Optional[bool] = None,
1242
  cache_position: Optional[torch.LongTensor] = None,
1243
+ num_logits_to_keep: int = 0,
1244
+ **loss_kwargs,
1245
+ ) -> Union[Tuple, CausalLMOutputWithPast]:
1246
  r"""
1247
+ Args:
1248
+ labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
1249
+ Labels for computing the masked language modeling loss. Indices should either be in `[0, ...,
1250
+ config.vocab_size]` or -100 (see `input_ids` docstring). Tokens with indices set to `-100` are ignored
1251
+ (masked), the loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]`.
1252
+
1253
+ num_logits_to_keep (`int`, *optional*):
1254
+ Calculate logits for the last `num_logits_to_keep` tokens. If `0`, calculate logits for all
1255
+ `input_ids` (special case). Only last token logits are needed for generation, and calculating them only for that
1256
+ token can save memory, which becomes pretty significant for long sequences or large vocabulary size.
1257
+
1258
+ Returns:
1259
 
1260
  Example:
1261
 
 
1277
  output_hidden_states = (
1278
  output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
1279
  )
1280
+ return_dict = return_dict if return_dict is not None else self.config.use_return_dict
1281
 
1282
  # decoder outputs consists of (dec_features, layer_state, dec_hidden, dec_attn)
1283
+ outputs = self.model(
1284
  input_ids=input_ids,
1285
  attention_mask=attention_mask,
1286
  position_ids=position_ids,
 
1289
  use_cache=use_cache,
1290
  output_attentions=output_attentions,
1291
  output_hidden_states=output_hidden_states,
1292
+ return_dict=return_dict,
1293
  cache_position=cache_position,
 
1294
  )
1295
 
1296
+ hidden_states = outputs[0]
1297
+ if self.config.pretraining_tp > 1:
1298
+ lm_head_slices = self.lm_head.weight.split(self.vocab_size // self.config.pretraining_tp, dim=0)
1299
+ logits = [F.linear(hidden_states, lm_head_slices[i]) for i in range(self.config.pretraining_tp)]
1300
+ logits = torch.cat(logits, dim=-1)
1301
+ else:
1302
+ # Only compute necessary logits, and do not upcast them to float if we are not computing the loss
1303
+ logits = self.lm_head(hidden_states[:, -num_logits_to_keep:, :])
1304
 
1305
  loss = None
1306
  if labels is not None:
1307
+ loss = self.loss_function(logits=logits, labels=labels, vocab_size=self.config.vocab_size, **loss_kwargs)
1308
+
1309
+ if not return_dict:
1310
+ output = (logits,) + outputs[1:]
1311
+ return (loss,) + output if loss is not None else output
1312
 
1313
  return CausalLMOutputWithPast(
1314
  loss=loss,
 
1319
  )
1320
 
1321
 
1322
+ @add_start_docstrings(
1323
+ """
1324
  The LLaMa Model transformer with a sequence classification head on top (linear layer).
1325
 
1326
  [`LlamaForSequenceClassification`] uses the last token in order to do the classification, as other causal models
 
1331
  no `pad_token_id` is defined, it simply takes the last value in each row of the batch. Since it cannot guess the
1332
  padding tokens when `inputs_embeds` are passed instead of `input_ids`, it does the same (take the last value in
1333
  each row of the batch).
1334
+ """,
1335
+ LLAMA_START_DOCSTRING,
1336
  )
1337
  class LlamaForSequenceClassification(LlamaPreTrainedModel):
1338
  def __init__(self, config):
 
1350
  def set_input_embeddings(self, value):
1351
  self.model.embed_tokens = value
1352
 
1353
+ @add_start_docstrings_to_model_forward(LLAMA_INPUTS_DOCSTRING)
 
1354
  def forward(
1355
  self,
1356
  input_ids: Optional[torch.LongTensor] = None,
1357
  attention_mask: Optional[torch.Tensor] = None,
1358
  position_ids: Optional[torch.LongTensor] = None,
1359
+ past_key_values: Optional[Union[Cache, List[torch.FloatTensor]]] = None,
1360
  inputs_embeds: Optional[torch.FloatTensor] = None,
1361
  labels: Optional[torch.LongTensor] = None,
1362
  use_cache: Optional[bool] = None,
1363
  output_attentions: Optional[bool] = None,
1364
  output_hidden_states: Optional[bool] = None,
1365
+ return_dict: Optional[bool] = None,
1366
+ ) -> Union[Tuple, SequenceClassifierOutputWithPast]:
1367
  r"""
1368
  labels (`torch.LongTensor` of shape `(batch_size,)`, *optional*):
1369
  Labels for computing the sequence classification/regression loss. Indices should be in `[0, ...,
1370
  config.num_labels - 1]`. If `config.num_labels == 1` a regression loss is computed (Mean-Square loss), If
1371
  `config.num_labels > 1` a classification loss is computed (Cross-Entropy).
1372
  """
1373
+ return_dict = return_dict if return_dict is not None else self.config.use_return_dict
1374
 
1375
+ transformer_outputs = self.model(
1376
  input_ids,
1377
  attention_mask=attention_mask,
1378
  position_ids=position_ids,
 
1381
  use_cache=use_cache,
1382
  output_attentions=output_attentions,
1383
  output_hidden_states=output_hidden_states,
1384
+ return_dict=return_dict,
1385
  )
1386
+ hidden_states = transformer_outputs[0]
1387
  logits = self.score(hidden_states)
1388
 
1389
  if input_ids is not None:
 
1394
  if self.config.pad_token_id is None and batch_size != 1:
1395
  raise ValueError("Cannot handle batch sizes > 1 if no padding token is defined.")
1396
  if self.config.pad_token_id is None:
1397
+ sequence_lengths = -1
 
 
 
 
 
1398
  else:
1399
+ if input_ids is not None:
1400
+ # if no pad token found, use modulo instead of reverse indexing for ONNX compatibility
1401
+ sequence_lengths = torch.eq(input_ids, self.config.pad_token_id).int().argmax(-1) - 1
1402
+ sequence_lengths = sequence_lengths % input_ids.shape[-1]
1403
+ sequence_lengths = sequence_lengths.to(logits.device)
1404
+ else:
1405
+ sequence_lengths = -1
1406
 
1407
+ pooled_logits = logits[torch.arange(batch_size, device=logits.device), sequence_lengths]
1408
 
1409
  loss = None
1410
  if labels is not None:
1411
  loss = self.loss_function(logits=logits, labels=labels, pooled_logits=pooled_logits, config=self.config)
1412
 
1413
+ if not return_dict:
1414
+ output = (pooled_logits,) + transformer_outputs[1:]
1415
+ return ((loss,) + output) if loss is not None else output
1416
+
1417
  return SequenceClassifierOutputWithPast(
1418
  loss=loss,
1419
  logits=pooled_logits,
 
1423
  )
1424
 
1425
 
1426
+ @add_start_docstrings(
1427
+ """
1428
+ The Llama Model transformer with a span classification head on top for extractive question-answering tasks like
1429
+ SQuAD (a linear layer on top of the hidden-states output to compute `span start logits` and `span end logits`).
1430
+ """,
1431
+ LLAMA_START_DOCSTRING,
1432
+ )
1433
  class LlamaForQuestionAnswering(LlamaPreTrainedModel):
1434
  base_model_prefix = "transformer"
1435
 
 
1448
  def set_input_embeddings(self, value):
1449
  self.transformer.embed_tokens = value
1450
 
1451
+ @add_start_docstrings_to_model_forward(LLAMA_INPUTS_DOCSTRING)
 
1452
  def forward(
1453
  self,
1454
  input_ids: Optional[torch.LongTensor] = None,
1455
+ attention_mask: Optional[torch.FloatTensor] = None,
1456
  position_ids: Optional[torch.LongTensor] = None,
1457
+ past_key_values: Optional[Union[Cache, List[torch.FloatTensor]]] = None,
1458
  inputs_embeds: Optional[torch.FloatTensor] = None,
1459
  start_positions: Optional[torch.LongTensor] = None,
1460
  end_positions: Optional[torch.LongTensor] = None,
1461
  output_attentions: Optional[bool] = None,
1462
  output_hidden_states: Optional[bool] = None,
1463
+ return_dict: Optional[bool] = None,
1464
  **kwargs,
1465
+ ) -> Union[Tuple, QuestionAnsweringModelOutput]:
1466
+ r"""
1467
+ start_positions (`torch.LongTensor` of shape `(batch_size,)`, *optional*):
1468
+ Labels for position (index) of the start of the labelled span for computing the token classification loss.
1469
+ Positions are clamped to the length of the sequence (`sequence_length`). Position outside of the sequence
1470
+ are not taken into account for computing the loss.
1471
+ end_positions (`torch.LongTensor` of shape `(batch_size,)`, *optional*):
1472
+ Labels for position (index) of the end of the labelled span for computing the token classification loss.
1473
+ Positions are clamped to the length of the sequence (`sequence_length`). Position outside of the sequence
1474
+ are not taken into account for computing the loss.
1475
+ """
1476
+ return_dict = return_dict if return_dict is not None else self.config.use_return_dict
1477
+
1478
+ outputs = self.transformer(
1479
  input_ids,
1480
  attention_mask=attention_mask,
1481
  position_ids=position_ids,
 
1483
  inputs_embeds=inputs_embeds,
1484
  output_attentions=output_attentions,
1485
  output_hidden_states=output_hidden_states,
1486
+ return_dict=return_dict,
1487
  )
1488
 
1489
+ sequence_output = outputs[0]
1490
 
1491
  logits = self.qa_outputs(sequence_output)
1492
  start_logits, end_logits = logits.split(1, dim=-1)
 
1497
  if start_positions is not None and end_positions is not None:
1498
  loss = self.loss_function(start_logits, end_logits, start_positions, end_positions, **kwargs)
1499
 
1500
+ if not return_dict:
1501
+ output = (start_logits, end_logits) + outputs[2:]
1502
+ return ((loss,) + output) if loss is not None else output
1503
+
1504
  return QuestionAnsweringModelOutput(
1505
  loss=loss,
1506
  start_logits=start_logits,
 
1510
  )
1511
 
1512
 
1513
+ @add_start_docstrings(
1514
+ """
1515
+ The Llama Model transformer with a token classification head on top (a linear layer on top of the hidden-states
1516
+ output) e.g. for Named-Entity-Recognition (NER) tasks.
1517
+ """,
1518
+ LLAMA_START_DOCSTRING,
1519
+ )
1520
  class LlamaForTokenClassification(LlamaPreTrainedModel):
1521
  def __init__(self, config):
1522
  super().__init__(config)
 
1540
  def set_input_embeddings(self, value):
1541
  self.model.embed_tokens = value
1542
 
1543
+ @add_start_docstrings_to_model_forward(LLAMA_INPUTS_DOCSTRING)
1544
+ @add_code_sample_docstrings(
1545
+ checkpoint=_CHECKPOINT_FOR_DOC,
1546
+ output_type=TokenClassifierOutput,
1547
+ config_class=_CONFIG_FOR_DOC,
1548
+ )
1549
  def forward(
1550
  self,
1551
  input_ids: Optional[torch.LongTensor] = None,
1552
  attention_mask: Optional[torch.Tensor] = None,
1553
  position_ids: Optional[torch.LongTensor] = None,
1554
+ past_key_values: Optional[List[torch.FloatTensor]] = None,
1555
  inputs_embeds: Optional[torch.FloatTensor] = None,
1556
  labels: Optional[torch.LongTensor] = None,
1557
  use_cache: Optional[bool] = None,
1558
  output_attentions: Optional[bool] = None,
1559
  output_hidden_states: Optional[bool] = None,
1560
+ return_dict: Optional[bool] = None,
1561
+ ) -> Union[Tuple, TokenClassifierOutput]:
1562
  r"""
1563
  labels (`torch.LongTensor` of shape `(batch_size,)`, *optional*):
1564
  Labels for computing the sequence classification/regression loss. Indices should be in `[0, ...,
1565
  config.num_labels - 1]`. If `config.num_labels == 1` a regression loss is computed (Mean-Square loss), If
1566
  `config.num_labels > 1` a classification loss is computed (Cross-Entropy).
1567
  """
1568
+ return_dict = return_dict if return_dict is not None else self.config.use_return_dict
1569
 
1570
+ outputs = self.model(
1571
  input_ids,
1572
  attention_mask=attention_mask,
1573
  position_ids=position_ids,
 
1576
  use_cache=use_cache,
1577
  output_attentions=output_attentions,
1578
  output_hidden_states=output_hidden_states,
1579
+ return_dict=return_dict,
1580
  )
1581
+ sequence_output = outputs[0]
1582
  sequence_output = self.dropout(sequence_output)
1583
  logits = self.score(sequence_output)
1584
 
 
1586
  if labels is not None:
1587
  loss = self.loss_function(logits, labels, self.config)
1588
 
1589
+ if not return_dict:
1590
+ output = (logits,) + outputs[2:]
1591
+ return ((loss,) + output) if loss is not None else output
1592
+
1593
  return TokenClassifierOutput(
1594
  loss=loss,
1595
  logits=logits,
1596
  hidden_states=outputs.hidden_states,
1597
  attentions=outputs.attentions,
1598
  )