natmin322 commited on
Commit
454979d
·
1 Parent(s): 5e23c54
improve_gainlora/src/t5_gainlora_inflora.py CHANGED
@@ -681,8 +681,7 @@ class T5Attention(nn.Module):
681
  position_bias = torch.zeros(
682
  (1, self.n_heads, real_seq_length, key_length), device=scores.device, dtype=scores.dtype
683
  )
684
- if self.gradient_checkpointing and self.training:
685
- position_bias.requires_grad = True
686
  else:
687
  position_bias = self.compute_bias(real_seq_length, key_length, device=scores.device)
688
 
 
681
  position_bias = torch.zeros(
682
  (1, self.n_heads, real_seq_length, key_length), device=scores.device, dtype=scores.dtype
683
  )
684
+
 
685
  else:
686
  position_bias = self.compute_bias(real_seq_length, key_length, device=scores.device)
687
 
improve_gainlora/src/t5_specroute.py CHANGED
@@ -663,10 +663,8 @@ class T5Stack(T5PreTrainedModel):
663
  if self.gradient_checkpointing and self.training:
664
  def create_custom_forward(module):
665
  def custom_forward(*inputs):
666
- return tuple(module(*inputs, use_cache, output_attentions,
667
- key_attention_weights=key_attention_weights))
668
  return custom_forward
669
-
670
  # Use _gradient_checkpointing_func (set by new-format
671
  # gradient_checkpointing_enable) if available, else fallback
672
  gc_fn = getattr(self, '_gradient_checkpointing_func', None)
 
663
  if self.gradient_checkpointing and self.training:
664
  def create_custom_forward(module):
665
  def custom_forward(*inputs):
666
+ return tuple(module(*inputs, use_cache, output_attentions))
 
667
  return custom_forward
 
668
  # Use _gradient_checkpointing_func (set by new-format
669
  # gradient_checkpointing_enable) if available, else fallback
670
  gc_fn = getattr(self, '_gradient_checkpointing_func', None)