Update modeling_opt.py
modeling_opt.py CHANGED (+4 -3)
@@ -38,6 +38,7 @@ from transformers.utils import (
     add_code_sample_docstrings,
     add_start_docstrings,
     add_start_docstrings_to_model_forward,
+
     is_flash_attn_2_available,
     is_flash_attn_greater_or_equal_2_10,
     logging,
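For context on this import block: `is_flash_attn_2_available` is the transformers utility (already imported here from `transformers.utils`) that reports whether a usable flash-attn 2 package is installed. A minimal sketch of the usual guard pattern; the "eager" fallback below is an assumption for illustration, not part of this commit:

    from transformers.utils import is_flash_attn_2_available

    # Assumed guard: fall back to the default "eager" implementation when
    # flash-attn 2 is not installed; only the import appears in this commit.
    attn_implementation = "flash_attention_2" if is_flash_attn_2_available() else "eager"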
@@ -725,10 +726,10 @@ class OPTDecoderLayer(nn.Module):
         super().__init__()
         self.embed_dim = config.hidden_size

-        self.self_attn = OPT_ATTENTION_CLASSES[config.
+        self.self_attn = OPT_ATTENTION_CLASSES[config._attn_implementation](
             config=config, is_decoder=True)
         print(self.self_attn)
-        print(config.
+        print(config._attn_implementation)
         self.do_layer_norm_before = config.do_layer_norm_before
         self.dropout = config.dropout
         self.activation_fn = ACT2FN[config.activation_function]
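The replacement lines key the attention lookup on `config._attn_implementation` (the `-` lines are cut off mid-token in the source view, so only the `+` lines show the full expressions). `OPT_ATTENTION_CLASSES` is a dispatch table mapping an implementation name to an attention class. A self-contained sketch of the pattern, with stub classes standing in for the real implementations; the class names follow the upstream transformers layout and are an assumption here:

    import torch.nn as nn

    class OPTAttention(nn.Module): ...            # stub: eager implementation
    class OptFlashAttention2(OPTAttention): ...   # stub: flash-attn 2 implementation

    # Dispatch table keyed by the configured attention implementation,
    # mirroring what the diff indexes with config._attn_implementation.
    OPT_ATTENTION_CLASSES = {
        "eager": OPTAttention,
        "flash_attention_2": OptFlashAttention2,
    }

    attn_cls = OPT_ATTENTION_CLASSES["flash_attention_2"]  # -> OptFlashAttention2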
@@ -970,7 +971,7 @@ class OPTDecoder(OPTPreTrainedModel):

         self.layers = nn.ModuleList(
             [OPTDecoderLayer(config) for _ in range(config.num_hidden_layers)])
-        self._use_flash_attention_2 = config.
+        self._use_flash_attention_2 = config._attn_implementation == "flash_attention_2"

         self.gradient_checkpointing = False
         # Initialize weights and apply final processing
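With this change, `_use_flash_attention_2` is derived directly from `config._attn_implementation`, which in recent transformers versions is populated from the `attn_implementation` argument at load time. A hedged usage sketch; "facebook/opt-125m" is just an example checkpoint, not one referenced by this commit:

    import torch
    from transformers import AutoModelForCausalLM

    # attn_implementation flows into config._attn_implementation, which the
    # decoder reads above; flash-attn 2 requires fp16/bf16 weights and an
    # environment with the flash-attn package installed.
    model = AutoModelForCausalLM.from_pretrained(
        "facebook/opt-125m",
        torch_dtype=torch.float16,
        attn_implementation="flash_attention_2",
    )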