small fix
- README.md +1 -1
- modeling_lsg_bert.py +15 -3
README.md CHANGED

@@ -7,7 +7,7 @@ pipeline_tag: fill-mask
 ---
 
 # LSG model
-**Transformers >= 4.
+**Transformers >= 4.36.1**\
 **This model relies on a custom modeling file, you need to add trust_remote_code=True**\
 **See [\#13467](https://github.com/huggingface/transformers/pull/13467)**
 
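As the README states, the checkpoint ships its own modeling code, so loading it goes through the custom-code path from PR [\#13467](https://github.com/huggingface/transformers/pull/13467). A minimal loading sketch; the repo id below is a placeholder, not taken from this commit:

```python
from transformers import AutoModelForMaskedLM, AutoTokenizer

# Placeholder repo id: substitute the actual LSG checkpoint.
model_id = "user/lsg-bert-base-uncased"

tokenizer = AutoTokenizer.from_pretrained(model_id)

# trust_remote_code=True lets transformers execute the repo's custom
# modeling_lsg_bert.py instead of the stock BERT implementation.
model = AutoModelForMaskedLM.from_pretrained(model_id, trust_remote_code=True)
```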
modeling_lsg_bert.py CHANGED

@@ -411,8 +411,13 @@ class LSGBertEmbeddings(BertEmbeddings):
         self.block_size = config.block_size
 
     def forward(
-        self,
-        input_ids=None, token_type_ids=None, position_ids=None, inputs_embeds=None, past_key_values_length=0):
+        self,
+        input_ids: Optional[torch.LongTensor] = None,
+        token_type_ids: Optional[torch.LongTensor] = None,
+        position_ids: Optional[torch.LongTensor] = None,
+        inputs_embeds: Optional[torch.FloatTensor] = None,
+        past_key_values_length: int = 0,
+    ) -> torch.Tensor:
         if input_ids is not None:
             input_shape = input_ids.size()
         else:

@@ -1005,6 +1010,7 @@ class LSGBertEncoder(BertEncoder):
         encoder_outputs.last_hidden_state = sequence_output
         return encoder_outputs
 
+
 class LSGBertPreTrainedModel(BertPreTrainedModel):
     """
     An abstract class to handle weights initialization and a simple interface for downloading and loading pretrained
@@ -1039,6 +1045,12 @@ class LSGBertModel(LSGBertPreTrainedModel, BertModel):
                 "Cross attention is computed using full attention since it is not LSG compatible."
             )
 
+        self._use_flash_attention_2 = config._attn_implementation == "flash_attention_2"
+        if self._use_flash_attention_2:
+            logger.warning(
+                "[WARNING flash-attention]: LSG doesn't support flash-attention currently"
+            )
+
         # Initialize weights and apply final processing
         self.post_init()
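The added guard records whether flash-attention 2 was requested on the config and warns that LSG keeps its own sparse attention instead of failing. A standalone sketch of the same pattern; the class below is an illustrative stand-in, not the real LSGBertModel:

```python
import logging

logger = logging.getLogger(__name__)


class FlashAttentionGuardSketch:
    """Illustrative stand-in for the check added to LSGBertModel.__init__."""

    def __init__(self, config):
        # config._attn_implementation is set by transformers >= 4.36 when the
        # caller requests an attention backend; "eager" is the usual default.
        self._use_flash_attention_2 = (
            getattr(config, "_attn_implementation", "eager") == "flash_attention_2"
        )
        if self._use_flash_attention_2:
            # LSG replaces self-attention with its own sparse implementation,
            # so a flash-attention request is acknowledged but not honored.
            logger.warning(
                "[WARNING flash-attention]: LSG doesn't support flash-attention currently"
            )
```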
|
@@ -1228,4 +1240,4 @@ try:
     str_to_class(value.split(".")[-1]).register_for_auto_class(key)
 except:
     warn("AutoRegister isn't available, you'll have to manually copy modeling.py after .save_pretrained(...).")
-    warn("Update to transformers >= 4.
+    warn("Update to transformers >= 4.36.1 to fix.")