maxoul committed on
Commit
a1702e0
·
verified ·
1 Parent(s): 03203c6

Update splade.py

Browse files
Files changed (1) hide show
  1. splade.py +7 -1
splade.py CHANGED
@@ -45,20 +45,26 @@ class Splade(PreTrainedModel):
45
  def __init__(self, config, weights_path=None):
46
  super().__init__(config)
47
  self.name = "splade"
 
48
  base_cfg = AutoConfig.from_pretrained(
49
  config.model_name_or_path,
50
  attn_implementation=config.attn_implementation,
51
  torch_dtype="auto",
52
  )
 
53
  self.tokenizer = prepare_tokenizer(
54
  config.model_name_or_path, padding_side=config.padding_side
55
  )
 
56
  if is_flash_attn_2_available():
57
  config.attn_implementation = "flash_attention_2"
58
  else:
59
  config.attn_implementation = "sdpa"
 
 
 
60
  self.model = get_decoder_model(
61
- model_name_or_path=config.model_name_or_path,
62
  attn_implementation=config.attn_implementation,
63
  bidirectional=getattr(config, "bidirectional", False),
64
  base_cfg=base_cfg,
 
45
  def __init__(self, config, weights_path=None):
46
  super().__init__(config)
47
  self.name = "splade"
48
+
49
  base_cfg = AutoConfig.from_pretrained(
50
  config.model_name_or_path,
51
  attn_implementation=config.attn_implementation,
52
  torch_dtype="auto",
53
  )
54
+
55
  self.tokenizer = prepare_tokenizer(
56
  config.model_name_or_path, padding_side=config.padding_side
57
  )
58
+
59
  if is_flash_attn_2_available():
60
  config.attn_implementation = "flash_attention_2"
61
  else:
62
  config.attn_implementation = "sdpa"
63
+
64
+ source = weights_path or config.model_name_or_path
65
+
66
  self.model = get_decoder_model(
67
+ model_name_or_path=source,
68
  attn_implementation=config.attn_implementation,
69
  bidirectional=getattr(config, "bidirectional", False),
70
  base_cfg=base_cfg,