codys12 committed on
Commit
591280c
·
verified ·
1 Parent(s): 03c5be1

Upload modeling_hunyuan.py

Browse files
Files changed (1) hide show
  1. modeling_hunyuan.py +143 -102
modeling_hunyuan.py CHANGED
@@ -1,16 +1,5 @@
1
- # Copyright (C) 2024 THL A29 Limited, a Tencent company. All rights reserved.
2
- #
3
- # Licensed under the TENCENT HUNYUAN COMMUNITY LICENSE AGREEMENT (the "License");
4
- # you may not use this file except in compliance with the License.
5
- # You may obtain a copy of the License at
6
- #
7
- # https://github.com/Tencent/Tencent-Hunyuan-Large/blob/main/License.docx
8
- #
9
- # Unless required by applicable law or agreed to in writing, software
10
- # distributed under the License is distributed on an "AS IS" BASIS,
11
- # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
- # See the License for the specific language governing permissions and
13
- # limitations under the License.
14
  #
15
  """ PyTorch HunYuan model."""
16
 
@@ -74,8 +63,7 @@ _CONFIG_FOR_DOC = "HunYuanConfig"
74
  def topkgating(logits: Tensor, topk: int):
75
  logits = logits.float()
76
  gates = F.softmax(logits, dim=1)
77
- # expert_capacity = topk * gates.shape[0]
78
- expert_capacity = max(topk, topk * gates.shape[0] // gates.shape[1])
79
  num_experts = int(gates.shape[1])
80
  # Top-k router probability and corresponding expert indices for each token.
81
  # Shape: [tokens_per_group, num_selected_experts].
@@ -265,7 +253,7 @@ class HunYuanRotaryEmbedding(nn.Module):
265
  self.max_position_embeddings = max_position_embeddings
266
  self.base = base
267
  inv_freq = 1.0 / (self.base ** (torch.arange(0, self.dim, 2).float().to(device) / self.dim))
268
- # inv_freq = inv_freq.bfloat16()
269
  self.register_buffer("inv_freq", inv_freq, persistent=False)
270
 
271
  # Build here to make `torch.jit.trace` work.
@@ -277,7 +265,6 @@ class HunYuanRotaryEmbedding(nn.Module):
277
  self.max_seq_len_cached = seq_len
278
  t = torch.arange(self.max_seq_len_cached, device=device, dtype=torch.float32)
279
 
280
- self.inv_freq = 1.0 / (self.base ** (torch.arange(0, self.dim, 2).float().to(device) / self.dim))
281
  freqs = torch.outer(t, self.inv_freq)
282
  # Different from paper, but it uses a different permutation in order to obtain the same calculation
283
  emb = torch.cat((freqs, freqs), dim=-1).float()
@@ -286,7 +273,7 @@ class HunYuanRotaryEmbedding(nn.Module):
286
 
287
  def forward(self, x, seq_len=None):
288
  # x: [bs, num_attention_heads, seq_len, head_size]
289
- if seq_len > self.max_seq_len_cached or self.inv_freq.dtype != torch.float32:
290
  self._set_cos_sin_cache(seq_len=seq_len, device=x.device, dtype=x.dtype)
291
 
292
  return (
@@ -411,7 +398,7 @@ class HunYuanMLP(nn.Module):
411
  self.layer_idx = layer_idx
412
  self.hidden_size = config.hidden_size
413
  if is_shared_mlp:
414
- self.intermediate_size = config.intermediate_size * config.num_shared_expert[0]
415
  else:
416
  self.intermediate_size = config.intermediate_size
417
  self.gate_proj = nn.Linear(self.hidden_size, self.intermediate_size, bias=False)
@@ -462,66 +449,142 @@ class HunYuanTopKGate(nn.Module):
462
  if self.moe_topk == 1:
463
  gate_output = top1gating(logits, random_routing_dropped_token=self.random_routing_dropped_token)
464
  else:
465
- gate_output = topkgating(logits, self.moe_topk[0])
466
 
467
  return gate_output
468
 
469
 
470
  class HunYuanMoE(nn.Module):
 
 
 
 
 
 
 
471
  def __init__(self, config: HunYuanConfig, layer_idx: Optional[int] = None):
472
  super().__init__()
473
  self.config = config
474
  self.layer_idx = layer_idx
475
  self.moe_topk = config.moe_topk
476
  self.num_experts = config.num_experts
 
 
477
  if config.use_mixed_mlp_moe:
478
  self.shared_mlp = HunYuanMLP(config, layer_idx=layer_idx, is_shared_mlp=True)
 
 
479
  self.gate = HunYuanTopKGate(config, layer_idx=layer_idx)
 
 
480
  self.experts = nn.ModuleList(
481
  [HunYuanMLP(config, layer_idx=layer_idx, is_shared_mlp=False) for _ in range(config.num_experts)]
482
  )
483
 
484
- def forward(self, hidden_states):
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
485
  bsz, seq_len, hidden_size = hidden_states.shape
486
 
 
487
  if self.config.use_mixed_mlp_moe:
488
  hidden_states_mlp = self.shared_mlp(hidden_states)
489
 
490
- l_moe, combine_weights, dispatch_mask, exp_counts = self.gate(hidden_states)
491
-
492
- reshaped_input = hidden_states.reshape(-1, hidden_size)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
493
 
494
- dispatched_input = torch.einsum("sec,sm->ecm", dispatch_mask.type_as(hidden_states), reshaped_input)
 
495
 
496
- chunks = dispatched_input.chunk(self.num_experts, dim=0)
497
- expert_outputs = []
498
- for chunk, expert in zip(chunks, self.experts):
499
- expert_outputs.append(expert(chunk))
500
 
501
- expert_output = torch.cat(expert_outputs, dim=0)
502
- combined_output = torch.einsum("sec,ecm->sm", combine_weights.type_as(hidden_states), expert_output)
503
- combined_output = combined_output.reshape(bsz, seq_len, hidden_size)
 
 
 
504
 
505
- if self.config.use_mixed_mlp_moe:
506
- output = hidden_states_mlp + combined_output
507
- else:
508
- output = combined_output
 
509
 
510
- return output
 
 
 
511
 
 
 
 
 
 
512
 
513
- def repeat_kv(hidden_states: torch.Tensor, n_rep: int) -> torch.Tensor:
514
- """
515
- This is the equivalent of torch.repeat_interleave(x, dim=1, repeats=n_rep). The hidden states go from (batch,
516
- num_key_value_heads, seqlen, head_dim) to (batch, num_attention_heads, seqlen, head_dim)
517
- """
518
- batch, num_key_value_heads, slen, head_dim = hidden_states.shape
519
- if n_rep == 1:
520
- return hidden_states
521
- hidden_states = hidden_states[:, :, None, :, :].expand(batch, num_key_value_heads, n_rep, slen, head_dim)
522
- return hidden_states.reshape(batch, num_key_value_heads * n_rep, slen, head_dim)
523
 
 
 
 
524
 
 
525
  class HunYuanAttention(nn.Module):
526
  """Multi-headed attention from 'Attention Is All You Need' paper"""
527
 
@@ -1064,33 +1127,18 @@ class HunYuanDecoderLayer(nn.Module):
1064
  kv_states: Optional[Tuple[torch.Tensor]] = None,
1065
  **kwargs,
1066
  ) -> Tuple[torch.FloatTensor, Optional[Tuple[torch.FloatTensor, torch.FloatTensor]]]:
1067
- """
1068
- Args:
1069
- hidden_states (`torch.FloatTensor`): input to the layer of shape `(batch, seq_len, embed_dim)`
1070
- attention_mask (`torch.FloatTensor`, *optional*):
1071
- attention mask of size `(batch_size, sequence_length)` if flash attention is used or `(batch_size, 1,
1072
- query_sequence_length, key_sequence_length)` if default attention is used.
1073
- output_attentions (`bool`, *optional*):
1074
- Whether or not to return the attentions tensors of all attention layers. See `attentions` under
1075
- returned tensors for more detail.
1076
- use_cache (`bool`, *optional*):
1077
- If set to `True`, `past_key_values` key value states are returned and can be used to speed up decoding
1078
- (see `past_key_values`).
1079
- past_key_value (`Tuple(torch.FloatTensor)`, *optional*): cached past key and value projection states
1080
- kv_states (`Tuple(torch.FloatTensor)`, *optional*): Used when CLA is enabled,
1081
- key and value states from past attention blocks
1082
- """
1083
  if "padding_mask" in kwargs:
1084
  warnings.warn(
1085
  "Passing `padding_mask` is deprecated and will be removed in v4.37. Please make sure use "
1086
  "`attention_mask` instead.`"
1087
  )
1088
-
 
1089
  residual = hidden_states
1090
-
1091
  hidden_states = self.input_layernorm(hidden_states)
1092
-
1093
- # Self Attention
1094
  hidden_states, self_attn_weights, present_key_value, kv_states = self.self_attn(
1095
  hidden_states=hidden_states,
1096
  attention_mask=attention_mask,
@@ -1102,47 +1150,40 @@ class HunYuanDecoderLayer(nn.Module):
1102
  **kwargs,
1103
  )
1104
  hidden_states = residual + hidden_states
1105
-
1106
- # Fully Connected
1107
  residual = hidden_states
1108
  hidden_states = self.post_attention_layernorm(hidden_states)
1109
- hidden_states = self.mlp(hidden_states)
1110
- hidden_states = residual + hidden_states
1111
-
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1112
  outputs = (hidden_states,)
1113
-
1114
  if output_attentions:
1115
  outputs += (self_attn_weights,)
1116
-
1117
  if use_cache:
1118
  outputs += (present_key_value,)
1119
-
1120
  outputs += (kv_states,)
1121
-
1122
  return outputs
1123
-
1124
-
1125
- HUNYUAN_START_DOCSTRING = r"""
1126
- This model inherits from [`PreTrainedModel`]. Check the superclass documentation for the generic methods the
1127
- library implements for all its model (such as downloading or saving, resizing the input embeddings, pruning heads
1128
- etc.)
1129
-
1130
- This model is also a PyTorch [torch.nn.Module](https://pytorch.org/docs/stable/nn.html#torch.nn.Module) subclass.
1131
- Use it as a regular PyTorch Module and refer to the PyTorch documentation for all matter related to general usage
1132
- and behavior.
1133
-
1134
- Parameters:
1135
- config ([`HunYuanConfig`]):
1136
- Model configuration class with all the parameters of the model. Initializing with a config file does not
1137
- load the weights associated with the model, only the configuration. Check out the
1138
- [`~PreTrainedModel.from_pretrained`] method to load the model weights.
1139
- """
1140
-
1141
-
1142
- @add_start_docstrings(
1143
- "The bare HunYuan Model outputting raw hidden-states without any specific head on top.",
1144
- HUNYUAN_START_DOCSTRING,
1145
- )
1146
  class HunYuanPreTrainedModel(PreTrainedModel):
1147
  config_class = HunYuanConfig
1148
  base_model_prefix = "model"
@@ -1417,7 +1458,7 @@ class HunYuanModel(HunYuanPreTrainedModel):
1417
  )
1418
 
1419
 
1420
- class HunYuanMoEV1ForCausalLM(HunYuanPreTrainedModel):
1421
  _tied_weights_keys = ["lm_head.weight"]
1422
 
1423
  def __init__(self, config: HunYuanConfig):
@@ -1547,7 +1588,7 @@ class HunYuanMoEV1ForCausalLM(HunYuanPreTrainedModel):
1547
  if isinstance(past_key_values, Cache):
1548
  cache_length = past_key_values.get_seq_length()
1549
  past_length = past_key_values.seen_tokens
1550
- max_cache_length = past_key_values.get_max_cache_shape()
1551
  else:
1552
  cache_length = past_length = past_key_values[0][0].shape[2]
1553
  max_cache_length = None
@@ -1725,4 +1766,4 @@ class HunYuanForSequenceClassification(HunYuanPreTrainedModel):
1725
  past_key_values=transformer_outputs.past_key_values,
1726
  hidden_states=transformer_outputs.hidden_states,
1727
  attentions=transformer_outputs.attentions,
1728
- )
 
1
+ # coding=utf-8
2
+ # Copyright 2024 Tencent Inc. All Rights Reserved.
 
 
 
 
 
 
 
 
 
 
 
3
  #
4
  """ PyTorch HunYuan model."""
5
 
 
63
  def topkgating(logits: Tensor, topk: int):
64
  logits = logits.float()
65
  gates = F.softmax(logits, dim=1)
66
+ expert_capacity = topk * gates.shape[0]
 
67
  num_experts = int(gates.shape[1])
68
  # Top-k router probability and corresponding expert indices for each token.
69
  # Shape: [tokens_per_group, num_selected_experts].
 
253
  self.max_position_embeddings = max_position_embeddings
254
  self.base = base
255
  inv_freq = 1.0 / (self.base ** (torch.arange(0, self.dim, 2).float().to(device) / self.dim))
256
+ inv_freq = inv_freq.bfloat16()
257
  self.register_buffer("inv_freq", inv_freq, persistent=False)
258
 
259
  # Build here to make `torch.jit.trace` work.
 
265
  self.max_seq_len_cached = seq_len
266
  t = torch.arange(self.max_seq_len_cached, device=device, dtype=torch.float32)
267
 
 
268
  freqs = torch.outer(t, self.inv_freq)
269
  # Different from paper, but it uses a different permutation in order to obtain the same calculation
270
  emb = torch.cat((freqs, freqs), dim=-1).float()
 
273
 
274
  def forward(self, x, seq_len=None):
275
  # x: [bs, num_attention_heads, seq_len, head_size]
276
+ if seq_len > self.max_seq_len_cached:
277
  self._set_cos_sin_cache(seq_len=seq_len, device=x.device, dtype=x.dtype)
278
 
279
  return (
 
398
  self.layer_idx = layer_idx
399
  self.hidden_size = config.hidden_size
400
  if is_shared_mlp:
401
+ self.intermediate_size = config.intermediate_size * config.num_shared_expert
402
  else:
403
  self.intermediate_size = config.intermediate_size
404
  self.gate_proj = nn.Linear(self.hidden_size, self.intermediate_size, bias=False)
 
449
  if self.moe_topk == 1:
450
  gate_output = top1gating(logits, random_routing_dropped_token=self.random_routing_dropped_token)
451
  else:
452
+ gate_output = topkgating(logits, self.moe_topk)
453
 
454
  return gate_output
455
 
456
 
457
class HunYuanMoE(nn.Module):
    """Mixture-of-Experts block with vectorized expert execution and
    Straight-Through Estimator (STE) utilities.

    This implementation removes all Python-side loops over experts. Expert
    parameters are stacked on-the-fly and all experts are executed in a single
    batched matmul sequence, which allows efficient tensor-parallel execution
    (e.g. with DeepSpeed ZeRO-3) while keeping the state-dict format unchanged
    (each expert remains an individual sub-module for full compatibility with
    existing checkpoints).
    """

    def __init__(self, config: HunYuanConfig, layer_idx: Optional[int] = None):
        super().__init__()
        self.config = config
        self.layer_idx = layer_idx
        self.moe_topk = config.moe_topk
        self.num_experts = config.num_experts

        # Optional shared MLP branch (mixed MoE + dense).
        if config.use_mixed_mlp_moe:
            self.shared_mlp = HunYuanMLP(config, layer_idx=layer_idx, is_shared_mlp=True)

        # Router.
        self.gate = HunYuanTopKGate(config, layer_idx=layer_idx)

        # Experts kept as individual sub-modules so that load_state_dict /
        # save_pretrained stay identical to the original checkpoint layout.
        self.experts = nn.ModuleList(
            [HunYuanMLP(config, layer_idx=layer_idx, is_shared_mlp=False) for _ in range(config.num_experts)]
        )

    # ---------------------------------------------------------------------
    # Internal helpers
    # ---------------------------------------------------------------------
    def _stack_weights(self):
        """Return stacked (batched) expert weights for gate/up/down projections.

        Shapes (E experts, hidden size H, intermediate size I):
            Wg : (E, I, H) -- gate_proj weights in nn.Linear (out, in) layout
            Wu : (E, I, H) -- up_proj weights in nn.Linear (out, in) layout
            Wd : (E, H, I) -- down_proj weights in nn.Linear (out, in) layout;
                              the call sites apply ``Wd.transpose(-2, -1)`` so
                              that ``torch.matmul(act, ...)`` yields (E, C, H).
        """
        Wg = torch.stack([exp.gate_proj.weight for exp in self.experts], dim=0)
        Wu = torch.stack([exp.up_proj.weight for exp in self.experts], dim=0)
        Wd = torch.stack([exp.down_proj.weight for exp in self.experts], dim=0)
        return Wg, Wu, Wd

    # ---------------------------------------------------------------------
    # Public API
    # ---------------------------------------------------------------------
    def forward(
        self,
        hidden_states: torch.Tensor,
        *,
        return_router_logits: bool = False,
    ):
        """Sparse Top-k MoE forward pass ("y_sparse") with optional router logits.

        This is the route used during both training and inference for the
        *sparse* branch. No Python loops over experts are used.

        Args:
            hidden_states: (B, T, H) input activations.
            return_router_logits: when True, additionally return the raw router
                logits of shape (B, T, E) (recomputed from ``self.gate.wg``).

        Returns:
            (B, T, H) output tensor, or a ``(output, router_logits)`` tuple when
            ``return_router_logits`` is set.
        """
        bsz, seq_len, hidden_size = hidden_states.shape

        # Optional dense branch when using mixed MoE.
        if self.config.use_mixed_mlp_moe:
            hidden_states_mlp = self.shared_mlp(hidden_states)

        # ---------------- Routing ----------------
        l_moe, combine_weights, dispatch_mask, _ = self.gate(hidden_states)

        flat_input = hidden_states.reshape(-1, hidden_size)  # (S, H) where S = B*T
        # Dispatch tokens -> experts: (E, C, H) with C the per-expert capacity.
        dispatched_input = torch.einsum(
            "sec,sm->ecm", dispatch_mask.to(hidden_states.dtype), flat_input
        )

        # ---------------- Expert computation ----------------
        Wg, Wu, Wd = self._stack_weights()
        Wg = Wg.to(hidden_states.dtype)
        Wu = Wu.to(hidden_states.dtype)
        Wd = Wd.to(hidden_states.dtype)

        gate_out = torch.einsum("ech,eih->eci", dispatched_input, Wg)  # (E, C, I)
        up_out = torch.einsum("ech,eih->eci", dispatched_input, Wu)    # (E, C, I)
        # Assumes all experts share the same activation function -- TODO confirm.
        act_fn = self.experts[0].act_fn
        interm = act_fn(gate_out) * up_out                             # (E, C, I)
        expert_output = torch.matmul(interm, Wd.transpose(-2, -1))     # (E, C, H)

        # ---------------- Combine ----------------
        combined_output = torch.einsum(
            "sec,ecm->sm", combine_weights.to(hidden_states.dtype), expert_output
        )  # (S, H)
        combined_output = combined_output.reshape(bsz, seq_len, hidden_size)

        if self.config.use_mixed_mlp_moe:
            combined_output = hidden_states_mlp + combined_output

        if return_router_logits:
            router_logits = self.gate.wg(flat_input).view(bsz, seq_len, self.num_experts)
            return combined_output, router_logits
        return combined_output

    # ---------------------------------------------------------------------
    # Dense branch -- outputs *all* expert activations (no routing)
    # ---------------------------------------------------------------------
    @torch.no_grad()
    def forward_all(self, hidden_states: torch.Tensor):
        """Compute every expert on every token ("dense" MoE path).

        Runs under ``torch.no_grad()``: this path is intended as the
        memory-efficient reference branch for STE training, not for
        backpropagation.

        Returns:
            Tensor of shape (B, T, E, H) -- per-expert hidden states.
        """
        bsz, seq_len, hidden_size = hidden_states.shape
        flat_input = hidden_states.reshape(-1, hidden_size)  # (S, H)

        Wg, Wu, Wd = self._stack_weights()
        Wg = Wg.to(hidden_states.dtype)
        Wu = Wu.to(hidden_states.dtype)
        Wd = Wd.to(hidden_states.dtype)

        gate_out = torch.einsum("sh,eih->esi", flat_input, Wg)  # (E, S, I)
        up_out = torch.einsum("sh,eih->esi", flat_input, Wu)    # (E, S, I)
        # Assumes all experts share the same activation function -- TODO confirm.
        act_fn = self.experts[0].act_fn
        interm = act_fn(gate_out) * up_out                      # (E, S, I)
        dense_out = torch.matmul(interm, Wd.transpose(-2, -1))  # (E, S, H)

        dense_out = dense_out.permute(1, 0, 2).contiguous()     # (S, E, H)
        dense_out = dense_out.view(bsz, seq_len, self.num_experts, hidden_size)  # (B, T, E, H)

        if self.config.use_mixed_mlp_moe:
            hidden_states_mlp = self.shared_mlp(hidden_states)  # (B, T, H)
            dense_out = dense_out + hidden_states_mlp.unsqueeze(2)

        return dense_out
588
  class HunYuanAttention(nn.Module):
589
  """Multi-headed attention from 'Attention Is All You Need' paper"""
590
 
 
1127
  kv_states: Optional[Tuple[torch.Tensor]] = None,
1128
  **kwargs,
1129
  ) -> Tuple[torch.FloatTensor, Optional[Tuple[torch.FloatTensor, torch.FloatTensor]]]:
1130
+ """Modified to include Straight‑Through Estimator (STE) training logic."""
1131
+
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1132
  if "padding_mask" in kwargs:
1133
  warnings.warn(
1134
  "Passing `padding_mask` is deprecated and will be removed in v4.37. Please make sure use "
1135
  "`attention_mask` instead.`"
1136
  )
1137
+
1138
+ # ---------------- Self‑Attention ----------------
1139
  residual = hidden_states
 
1140
  hidden_states = self.input_layernorm(hidden_states)
1141
+
 
1142
  hidden_states, self_attn_weights, present_key_value, kv_states = self.self_attn(
1143
  hidden_states=hidden_states,
1144
  attention_mask=attention_mask,
 
1150
  **kwargs,
1151
  )
1152
  hidden_states = residual + hidden_states
1153
+
1154
+ # ---------------- MLP / MoE (+STE) ----------------
1155
  residual = hidden_states
1156
  hidden_states = self.post_attention_layernorm(hidden_states)
1157
+
1158
+ if self.training and isinstance(self.mlp, HunYuanMoE):
1159
+ # Sparse path + router logits
1160
+ y_sparse, router_logits = self.mlp(hidden_states, return_router_logits=True)
1161
+
1162
+ # Dense (all‑experts) path – memory‑efficient, no grad
1163
+ with torch.no_grad():
1164
+ y_all = self.mlp.forward_all(hidden_states) # (B, T, E, H)
1165
+
1166
+ gate = router_logits.softmax(-1).unsqueeze(-1) # (B, T, E, 1)
1167
+ y_dense = (gate * y_all).sum(-2) # (B, T, H)
1168
+
1169
+ mlp_out = y_dense + (y_sparse - y_dense).detach()
1170
+ else:
1171
+ mlp_out = self.mlp(hidden_states)
1172
+
1173
+ hidden_states = residual + mlp_out
1174
+
1175
+ # ---------------- Outputs ----------------
1176
  outputs = (hidden_states,)
1177
+
1178
  if output_attentions:
1179
  outputs += (self_attn_weights,)
1180
+
1181
  if use_cache:
1182
  outputs += (present_key_value,)
1183
+
1184
  outputs += (kv_states,)
1185
+
1186
  return outputs
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1187
  class HunYuanPreTrainedModel(PreTrainedModel):
1188
  config_class = HunYuanConfig
1189
  base_model_prefix = "model"
 
1458
  )
1459
 
1460
 
1461
+ class HunYuanForCausalLM(HunYuanPreTrainedModel):
1462
  _tied_weights_keys = ["lm_head.weight"]
1463
 
1464
  def __init__(self, config: HunYuanConfig):
 
1588
  if isinstance(past_key_values, Cache):
1589
  cache_length = past_key_values.get_seq_length()
1590
  past_length = past_key_values.seen_tokens
1591
+ max_cache_length = past_key_values.get_max_length()
1592
  else:
1593
  cache_length = past_length = past_key_values[0][0].shape[2]
1594
  max_cache_length = None
 
1766
  past_key_values=transformer_outputs.past_key_values,
1767
  hidden_states=transformer_outputs.hidden_states,
1768
  attentions=transformer_outputs.attentions,
1769
+ )