maxoul committed on
Commit
5afcf9e
·
verified ·
1 Parent(s): f0c8335

Create splade.py

Browse files
Files changed (1) hide show
  1. splade.py +97 -0
splade.py ADDED
@@ -0,0 +1,97 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ from transformers import (
3
+ PretrainedConfig,
4
+ PreTrainedModel,
5
+ AutoConfig,
6
+ )
7
+ from huggingface_hub import snapshot_download
8
+ from typing import Optional
9
+ from transformers.utils import is_flash_attn_2_available
10
+ from .utils import (
11
+ get_decoder_model,
12
+ prepare_tokenizer,
13
+ splade_max,
14
+ similarity,
15
+ encode,
16
+ )
17
+
18
+
19
class SpladeConfig(PretrainedConfig):
    """Configuration for the SPLADE sparse-encoder wrapper.

    Stores the backbone checkpoint to load plus a handful of knobs that
    control how the decoder model and tokenizer are prepared.
    """

    model_type = "splade"

    def __init__(
        self,
        model_name_or_path: str = "meta-llama/Llama-3.1-8B",
        attn_implementation: str = "flash_attention_2",
        bidirectional: bool = True,  # only meaningful for decoder backbones
        padding_side: str = "right",
        **kwargs,
    ):
        super().__init__(**kwargs)
        # Persist every constructor argument verbatim as a config attribute.
        for attr_name, attr_value in (
            ("model_name_or_path", model_name_or_path),
            ("attn_implementation", attn_implementation),
            ("bidirectional", bidirectional),
            ("padding_side", padding_side),
        ):
            setattr(self, attr_name, attr_value)
35
+
36
+
37
class Splade(PreTrainedModel):
    """SPLADE-style sparse encoder wrapping a decoder language model.

    The backbone is loaded via ``get_decoder_model`` (optionally patched to be
    bidirectional) and sparse representations are produced by applying
    ``splade_max`` pooling over the LM's vocabulary logits.
    """

    config_class = SpladeConfig

    # Methods bound for MTEB's retrieval interface.
    similarity = similarity
    encode = encode

    def __init__(self, config):
        super().__init__(config)
        self.name = "splade"
        # Resolve the attention implementation BEFORE building any config or
        # model, so that base_cfg and the decoder agree. (Previously base_cfg
        # was created from the user-requested implementation first, which
        # could bake "flash_attention_2" into it on hosts where flash-attn
        # is not installed, while the model itself fell back to "sdpa".)
        config.attn_implementation = (
            "flash_attention_2" if is_flash_attn_2_available() else "sdpa"
        )
        base_cfg = AutoConfig.from_pretrained(
            config.model_name_or_path,
            attn_implementation=config.attn_implementation,
            torch_dtype="auto",
        )
        self.tokenizer = prepare_tokenizer(
            config.model_name_or_path, padding_side=config.padding_side
        )
        self.model = get_decoder_model(
            model_name_or_path=config.model_name_or_path,
            attn_implementation=config.attn_implementation,
            # bidirectional defaults to False if an older config lacks the field
            bidirectional=getattr(config, "bidirectional", False),
            base_cfg=base_cfg,
        )

    def save_pretrained(self, save_directory, *args, **kwargs):
        """Save the LoRA adapter under ``<save_directory>/lora`` and the config alongside it."""
        self.model.save_pretrained(os.path.join(save_directory, "lora"))
        self.config.save_pretrained(save_directory)

    @classmethod
    def from_pretrained(cls, model_name_or_path, *args, **kwargs):
        """Rebuild the model from a hub repo or local dir produced by ``save_pretrained``.

        NOTE: extra ``*args``/``**kwargs`` are accepted for interface
        compatibility but are intentionally not forwarded.
        """
        config = SpladeConfig.from_pretrained(model_name_or_path)
        model = cls(config)
        # snapshot_download is a no-op cache hit for an already-local path.
        local_dir = snapshot_download(model_name_or_path)
        adapter_path = os.path.join(local_dir, "lora")
        model.model.load_adapter(adapter_path)
        # id -> token map, used to render sparse vectors as readable terms.
        model.reverse_voc = {v: k for k, v in model.tokenizer.vocab.items()}
        return model

    def forward(self, **tokens):
        """Run the backbone and pool logits into a single sparse representation.

        Returns a 1-tuple ``(splade_reps,)`` for HF-style tuple outputs.
        """
        output = self.model(**tokens)
        splade_reps, _ = splade_max(output.logits, tokens["attention_mask"])
        return (splade_reps,)

    def get_width(self):
        """Dimensionality of the sparse representation (= backbone vocab size)."""
        return self.model.config.vocab_size

    def create_batch_dict(self, input_texts, max_length):
        """Tokenize a batch of texts, padded to the longest item and truncated to ``max_length``."""
        return self.tokenizer(
            input_texts,
            add_special_tokens=True,
            padding="longest",
            truncation=True,
            max_length=max_length,
            return_attention_mask=True,
            return_tensors="pt",
        )