Update splade.py
Browse files
splade.py
CHANGED
|
@@ -45,20 +45,26 @@ class Splade(PreTrainedModel):
|
|
| 45 |
def __init__(self, config, weights_path=None):
|
| 46 |
super().__init__(config)
|
| 47 |
self.name = "splade"
|
|
|
|
| 48 |
base_cfg = AutoConfig.from_pretrained(
|
| 49 |
config.model_name_or_path,
|
| 50 |
attn_implementation=config.attn_implementation,
|
| 51 |
torch_dtype="auto",
|
| 52 |
)
|
|
|
|
| 53 |
self.tokenizer = prepare_tokenizer(
|
| 54 |
config.model_name_or_path, padding_side=config.padding_side
|
| 55 |
)
|
|
|
|
| 56 |
if is_flash_attn_2_available():
|
| 57 |
config.attn_implementation = "flash_attention_2"
|
| 58 |
else:
|
| 59 |
config.attn_implementation = "sdpa"
|
|
|
|
|
|
|
|
|
|
| 60 |
self.model = get_decoder_model(
|
| 61 |
-
model_name_or_path=
|
| 62 |
attn_implementation=config.attn_implementation,
|
| 63 |
bidirectional=getattr(config, "bidirectional", False),
|
| 64 |
base_cfg=base_cfg,
|
|
|
|
| 45 |
def __init__(self, config, weights_path=None):
|
| 46 |
super().__init__(config)
|
| 47 |
self.name = "splade"
|
| 48 |
+
|
| 49 |
base_cfg = AutoConfig.from_pretrained(
|
| 50 |
config.model_name_or_path,
|
| 51 |
attn_implementation=config.attn_implementation,
|
| 52 |
torch_dtype="auto",
|
| 53 |
)
|
| 54 |
+
|
| 55 |
self.tokenizer = prepare_tokenizer(
|
| 56 |
config.model_name_or_path, padding_side=config.padding_side
|
| 57 |
)
|
| 58 |
+
|
| 59 |
if is_flash_attn_2_available():
|
| 60 |
config.attn_implementation = "flash_attention_2"
|
| 61 |
else:
|
| 62 |
config.attn_implementation = "sdpa"
|
| 63 |
+
|
| 64 |
+
source = weights_path or config.model_name_or_path
|
| 65 |
+
|
| 66 |
self.model = get_decoder_model(
|
| 67 |
+
model_name_or_path=source,
|
| 68 |
attn_implementation=config.attn_implementation,
|
| 69 |
bidirectional=getattr(config, "bidirectional", False),
|
| 70 |
base_cfg=base_cfg,
|