natmin322 commited on
Commit
ddb0466
·
1 Parent(s): 2200936
improve_gainlora/src/t5_specroute.py CHANGED
@@ -663,8 +663,13 @@ class T5Stack(T5PreTrainedModel):
663
  if self.gradient_checkpointing and self.training:
664
  def create_custom_forward(module):
665
  def custom_forward(*inputs):
666
- return tuple(module(*inputs, use_cache, output_attentions,
667
- key_attention_weights=key_attention_weights))
 
 
 
 
 
668
  return custom_forward
669
  # Use _gradient_checkpointing_func (set by new-format
670
  # gradient_checkpointing_enable) if available, else fallback
@@ -681,6 +686,7 @@ class T5Stack(T5PreTrainedModel):
681
  layer_head_mask,
682
  cross_attn_layer_head_mask,
683
  None,
 
684
  )
685
  else:
686
  layer_outputs = checkpoint(
@@ -694,7 +700,8 @@ class T5Stack(T5PreTrainedModel):
694
  layer_head_mask,
695
  cross_attn_layer_head_mask,
696
  None,
697
- use_reentrant=False,
 
698
  )
699
  else:
700
  layer_outputs = layer_module(
 
663
  if self.gradient_checkpointing and self.training:
664
  def create_custom_forward(module):
665
  def custom_forward(*inputs):
666
+ # inputs = (hidden_states, attn_mask, pos_bias, enc_hs, enc_attn_mask, enc_dec_pos_bias, layer_head_mask, cross_attn_head_mask, None, key_attn_weights)
667
+ return tuple(module(inputs[0], attention_mask=inputs[1], position_bias=inputs[2],
668
+ encoder_hidden_states=inputs[3], encoder_attention_mask=inputs[4],
669
+ encoder_decoder_position_bias=inputs[5], layer_head_mask=inputs[6],
670
+ cross_attn_layer_head_mask=inputs[7], past_key_value=inputs[8],
671
+ use_cache=use_cache, output_attentions=output_attentions,
672
+ key_attention_weights=inputs[9]))
673
  return custom_forward
674
  # Use _gradient_checkpointing_func (set by new-format
675
  # gradient_checkpointing_enable) if available, else fallback
 
686
  layer_head_mask,
687
  cross_attn_layer_head_mask,
688
  None,
689
+ key_attention_weights,
690
  )
691
  else:
692
  layer_outputs = checkpoint(
 
700
  layer_head_mask,
701
  cross_attn_layer_head_mask,
702
  None,
703
+ key_attention_weights,
704
+ use_reentrant=True,
705
  )
706
  else:
707
  layer_outputs = layer_module(