v10a
Browse files
improve_gainlora/src/t5_gainlora_inflora.py
CHANGED
|
@@ -681,8 +681,7 @@ class T5Attention(nn.Module):
|
|
| 681 |
position_bias = torch.zeros(
|
| 682 |
(1, self.n_heads, real_seq_length, key_length), device=scores.device, dtype=scores.dtype
|
| 683 |
)
|
| 684 |
-
|
| 685 |
-
position_bias.requires_grad = True
|
| 686 |
else:
|
| 687 |
position_bias = self.compute_bias(real_seq_length, key_length, device=scores.device)
|
| 688 |
|
|
|
|
| 681 |
position_bias = torch.zeros(
|
| 682 |
(1, self.n_heads, real_seq_length, key_length), device=scores.device, dtype=scores.dtype
|
| 683 |
)
|
| 684 |
+
|
|
|
|
| 685 |
else:
|
| 686 |
position_bias = self.compute_bias(real_seq_length, key_length, device=scores.device)
|
| 687 |
|
improve_gainlora/src/t5_specroute.py
CHANGED
|
@@ -663,10 +663,8 @@ class T5Stack(T5PreTrainedModel):
|
|
| 663 |
if self.gradient_checkpointing and self.training:
|
| 664 |
def create_custom_forward(module):
|
| 665 |
def custom_forward(*inputs):
|
| 666 |
-
return tuple(module(*inputs, use_cache, output_attentions
|
| 667 |
-
key_attention_weights=key_attention_weights))
|
| 668 |
return custom_forward
|
| 669 |
-
|
| 670 |
# Use _gradient_checkpointing_func (set by new-format
|
| 671 |
# gradient_checkpointing_enable) if available, else fallback
|
| 672 |
gc_fn = getattr(self, '_gradient_checkpointing_func', None)
|
|
|
|
| 663 |
if self.gradient_checkpointing and self.training:
|
| 664 |
def create_custom_forward(module):
|
| 665 |
def custom_forward(*inputs):
|
| 666 |
+
return tuple(module(*inputs, use_cache, output_attentions))
|
|
|
|
| 667 |
return custom_forward
|
|
|
|
| 668 |
# Use _gradient_checkpointing_func (set by new-format
|
| 669 |
# gradient_checkpointing_enable) if available, else fallback
|
| 670 |
gc_fn = getattr(self, '_gradient_checkpointing_func', None)
|