DChak2000 commited on
Commit
ec0dce9
·
verified ·
1 Parent(s): ae5f485

LoRA fine-tuned NV-Embed-v2 for Prolog–NL trace alignment

Browse files
1_Pooling/config.json ADDED
@@ -0,0 +1,10 @@
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "word_embedding_dimension": 4096,
3
+ "pooling_mode_cls_token": false,
4
+ "pooling_mode_mean_tokens": true,
5
+ "pooling_mode_max_tokens": false,
6
+ "pooling_mode_mean_sqrt_len_tokens": false,
7
+ "pooling_mode_weightedmean_tokens": false,
8
+ "pooling_mode_lasttoken": false,
9
+ "include_prompt": false
10
+ }
README.md ADDED
@@ -0,0 +1,89 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ ---
2
+ library_name: sentence-transformers
3
+ pipeline_tag: sentence-similarity
4
+ tags:
5
+ - sentence-transformers
6
+ - feature-extraction
7
+ - sentence-similarity
8
+
9
+ ---
10
+
11
+ # DChak2000/nv-embed-v2-trace-align
12
+
13
+ This is a [sentence-transformers](https://www.SBERT.net) model: It maps sentences & paragraphs to a 4096 dimensional dense vector space and can be used for tasks like clustering or semantic search.
14
+
15
+ <!--- Describe your model here -->
16
+
17
+ ## Usage (Sentence-Transformers)
18
+
19
+ Using this model becomes easy when you have [sentence-transformers](https://www.SBERT.net) installed:
20
+
21
+ ```
22
+ pip install -U sentence-transformers
23
+ ```
24
+
25
+ Then you can use the model like this:
26
+
27
+ ```python
28
+ from sentence_transformers import SentenceTransformer
29
+ sentences = ["This is an example sentence", "Each sentence is converted"]
30
+
31
+ model = SentenceTransformer('DChak2000/nv-embed-v2-trace-align')
32
+ embeddings = model.encode(sentences)
33
+ print(embeddings)
34
+ ```
35
+
36
+
37
+
38
+ ## Evaluation Results
39
+
40
+ <!--- Describe how your model was evaluated -->
41
+
42
+ For an automated evaluation of this model, see the *Sentence Embeddings Benchmark*: [https://seb.sbert.net](https://seb.sbert.net?model_name=DChak2000/nv-embed-v2-trace-align)
43
+
44
+
45
+ ## Training
46
+ The model was trained with the parameters:
47
+
48
+ **DataLoader**:
49
+
50
+ `torch.utils.data.dataloader.DataLoader` of length 1459 with parameters:
51
+ ```
52
+ {'batch_size': 2, 'sampler': 'torch.utils.data.sampler.RandomSampler', 'batch_sampler': 'torch.utils.data.sampler.BatchSampler'}
53
+ ```
54
+
55
+ **Loss**:
56
+
57
+ `sentence_transformers.losses.CosineSimilarityLoss.CosineSimilarityLoss`
58
+
59
+ Parameters of the fit()-Method:
60
+ ```
61
+ {
62
+ "epochs": 3,
63
+ "evaluation_steps": 0,
64
+ "evaluator": "NoneType",
65
+ "max_grad_norm": 1,
66
+ "optimizer_class": "<class 'torch.optim.adamw.AdamW'>",
67
+ "optimizer_params": {
68
+ "lr": 2e-05
69
+ },
70
+ "scheduler": "WarmupLinear",
71
+ "steps_per_epoch": null,
72
+ "warmup_steps": 200,
73
+ "weight_decay": 0.01
74
+ }
75
+ ```
76
+
77
+
78
+ ## Full Model Architecture
79
+ ```
80
+ SentenceTransformer(
81
+ (0): Transformer({'max_seq_length': 32768, 'do_lower_case': False}) with Transformer model: NVEmbedModel
82
+ (1): Pooling({'word_embedding_dimension': 4096, 'pooling_mode_cls_token': False, 'pooling_mode_mean_tokens': True, 'pooling_mode_max_tokens': False, 'pooling_mode_mean_sqrt_len_tokens': False, 'pooling_mode_weightedmean_tokens': False, 'pooling_mode_lasttoken': False, 'include_prompt': False})
83
+ (2): Normalize()
84
+ )
85
+ ```
86
+
87
+ ## Citing & Authors
88
+
89
+ <!--- Describe where people can find more information -->
adapter_config.json ADDED
@@ -0,0 +1,31 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "alpha_pattern": {},
3
+ "auto_mapping": null,
4
+ "base_model_name_or_path": "nvidia/NV-Embed-v2",
5
+ "bias": "none",
6
+ "fan_in_fan_out": false,
7
+ "inference_mode": false,
8
+ "init_lora_weights": true,
9
+ "layer_replication": null,
10
+ "layers_pattern": null,
11
+ "layers_to_transform": null,
12
+ "loftq_config": {},
13
+ "lora_alpha": 32,
14
+ "lora_dropout": 0.05,
15
+ "megatron_config": null,
16
+ "megatron_core": "megatron.core",
17
+ "modules_to_save": null,
18
+ "peft_type": "LORA",
19
+ "r": 16,
20
+ "rank_pattern": {},
21
+ "revision": null,
22
+ "target_modules": [
23
+ "k_proj",
24
+ "q_proj",
25
+ "v_proj",
26
+ "o_proj"
27
+ ],
28
+ "task_type": "FEATURE_EXTRACTION",
29
+ "use_dora": false,
30
+ "use_rslora": false
31
+ }
adapter_model.safetensors ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:e338285af8ea88f046a9ebade47bf0377e1050390de26ac94aaac0a7bead4744
3
+ size 27300104
config_sentence_transformers.json ADDED
@@ -0,0 +1,9 @@
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "__version__": {
3
+ "sentence_transformers": "2.7.0",
4
+ "transformers": "4.37.2",
5
+ "pytorch": "2.2.0+cu121"
6
+ },
7
+ "prompts": {},
8
+ "default_prompt_name": null
9
+ }
configuration_nvembed.py ADDED
@@ -0,0 +1,92 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+
2
+ from typing import Literal
3
+ from transformers import AutoConfig
4
+ from transformers.configuration_utils import PretrainedConfig
5
+ from transformers.models.auto import CONFIG_MAPPING
6
+ from transformers.models.mistral import MistralConfig
7
+
8
+ NVEMBED_TYPE = "nvembed"
9
+ LATENT_ATTENTION_TYPE = "latent_attention"
10
+ BIDIR_MISTRAL_TYPE = "bidir_mistral"
11
+
12
+ class NVEmbedConfig(PretrainedConfig):
13
+ model_type = "nvembed"
14
+ is_composition = False
15
+
16
+ def __init__(
17
+ self,
18
+ latent_attention_config=None,
19
+ text_config=None,
20
+ padding_side: Literal["right", "left"]="right",
21
+ add_pad_token: bool=True,
22
+ is_mask_instruction: bool = True,
23
+ add_eos: bool=True,
24
+ mask_type: str="b",
25
+ **kwargs,
26
+ ):
27
+ if isinstance(latent_attention_config, dict):
28
+ latent_attention_config["model_type"] = (
29
+ latent_attention_config["model_type"] if "model_type" in latent_attention_config else LATENT_ATTENTION_TYPE
30
+ )
31
+ latent_attention_config = CONFIG_MAPPING[latent_attention_config["model_type"]](**latent_attention_config)
32
+ elif latent_attention_config is None:
33
+ latent_attention_config = CONFIG_MAPPING[LATENT_ATTENTION_TYPE]()
34
+
35
+ self.latent_attention_config = latent_attention_config
36
+
37
+ if isinstance(text_config, dict):
38
+ text_config["model_type"] = text_config["model_type"] if "model_type" in text_config else "llama"
39
+ text_config = CONFIG_MAPPING[text_config["model_type"]](**text_config)
40
+ elif text_config is None:
41
+ text_config = None
42
+
43
+ self.text_config = text_config
44
+ self.padding_side = padding_side
45
+ self.is_mask_instruction = is_mask_instruction
46
+ self.add_pad_token = add_pad_token
47
+ self.add_eos = add_eos
48
+ self.mask_type = mask_type
49
+ if "hidden_size" in kwargs:
50
+ self.hidden_size = kwargs["hidden_size"]
51
+ else:
52
+ self.hidden_size = 4096
53
+
54
+ super().__init__(**kwargs)
55
+
56
+
57
+ class LatentAttentionConfig(PretrainedConfig):
58
+ model_type = LATENT_ATTENTION_TYPE
59
+ is_composition = False
60
+ _name_or_path = "latent_attention"
61
+
62
+ def __init__(
63
+ self,
64
+ num_latents_value: int=512,
65
+ num_cross_heads: int=8,
66
+ output_normalize: bool=True,
67
+ hidden_dim: int=4096,
68
+ latent_dim: int=4096,
69
+ cross_dim_head: int=4096,
70
+ **kwargs,
71
+ ):
72
+ self.num_latents_value = num_latents_value
73
+ self.num_cross_heads = num_cross_heads
74
+ self.output_normalize = output_normalize
75
+ self.hidden_dim = hidden_dim
76
+ self.latent_dim = latent_dim
77
+ self.cross_dim_head = cross_dim_head
78
+
79
+ super().__init__(**kwargs)
80
+
81
+
82
+ class BidirectionalMistralConfig(MistralConfig):
83
+ model_type = BIDIR_MISTRAL_TYPE
84
+ keys_to_ignore_at_inference = ["past_key_values"]
85
+
86
+ AutoConfig.register(NVEMBED_TYPE, NVEmbedConfig)
87
+ AutoConfig.register(LATENT_ATTENTION_TYPE, LatentAttentionConfig)
88
+ AutoConfig.register(BIDIR_MISTRAL_TYPE, BidirectionalMistralConfig)
89
+
90
+ NVEmbedConfig.register_for_auto_class()
91
+ LatentAttentionConfig.register_for_auto_class()
92
+ BidirectionalMistralConfig.register_for_auto_class()
modeling_nvembed.py ADDED
@@ -0,0 +1,439 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from typing import List, Union, Dict, Mapping, Optional, Tuple, TypedDict
2
+ import torch
3
+ import os
4
+ import json
5
+ import numpy as np
6
+ from functools import partial
7
+ from contextlib import nullcontext
8
+ from transformers import AutoModel, PreTrainedTokenizerFast, BatchEncoding, DataCollatorWithPadding
9
+ from transformers.modeling_utils import PreTrainedModel
10
+ from transformers.models.auto import AutoTokenizer
11
+ from transformers.modeling_outputs import BaseModelOutputWithPast
12
+ from transformers.modeling_attn_mask_utils import _prepare_4d_attention_mask, _prepare_4d_attention_mask_for_sdpa
13
+ from transformers import MistralModel, MistralConfig
14
+ from transformers.cache_utils import Cache, DynamicCache
15
+ from transformers.utils import (
16
+ add_start_docstrings_to_model_forward,
17
+ logging,
18
+ )
19
+ from einops import rearrange, repeat
20
+ from tqdm.auto import tqdm
21
+ from datasets import Dataset
22
+ from torch.utils.data import DataLoader
23
+ from .configuration_nvembed import NVEmbedConfig, LatentAttentionConfig, BidirectionalMistralConfig
24
+
25
+ logger = logging.get_logger(__name__)
26
+
27
+ class NVEmbedFeatures(TypedDict):
28
+ input_dict: torch.Tensor
29
+ attention_mask: torch.Tensor
30
+ pool_mask: torch.Tensor
31
+
32
+ class BidirectionalMistralModel(MistralModel):
33
+ config_class = BidirectionalMistralConfig
34
+
35
+ def __init__(self, config: MistralConfig):
36
+ super().__init__(config)
37
+ for layer in self.layers:
38
+ layer.self_attn.is_causal = False
39
+ self._attn_implementation = "eager"
40
+
41
+ def forward(
42
+ self,
43
+ input_ids: torch.LongTensor = None,
44
+ attention_mask: Optional[torch.Tensor] = None,
45
+ position_ids: Optional[torch.LongTensor] = None,
46
+ past_key_values: Optional[List[torch.FloatTensor]] = None,
47
+ inputs_embeds: Optional[torch.FloatTensor] = None,
48
+ use_cache: Optional[bool] = None,
49
+ output_attentions: Optional[bool] = None,
50
+ output_hidden_states: Optional[bool] = None,
51
+ return_dict: Optional[bool] = None,
52
+ ) -> Union[Tuple, BaseModelOutputWithPast]:
53
+ output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
54
+ output_hidden_states = (
55
+ output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
56
+ )
57
+ use_cache = use_cache if use_cache is not None else self.config.use_cache
58
+
59
+ return_dict = return_dict if return_dict is not None else self.config.use_return_dict
60
+
61
+ # retrieve input_ids and inputs_embeds
62
+ if input_ids is not None and inputs_embeds is not None:
63
+ raise ValueError("You cannot specify both decoder_input_ids and decoder_inputs_embeds at the same time")
64
+ elif input_ids is not None:
65
+ batch_size, seq_length = input_ids.shape
66
+ elif inputs_embeds is not None:
67
+ batch_size, seq_length, _ = inputs_embeds.shape
68
+ else:
69
+ raise ValueError("You have to specify either decoder_input_ids or decoder_inputs_embeds")
70
+
71
+ if self.gradient_checkpointing and self.training:
72
+ if use_cache:
73
+ logger.warning_once(
74
+ "`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`..."
75
+ )
76
+ use_cache = False
77
+
78
+ past_key_values_length = 0
79
+
80
+ if use_cache:
81
+ use_legacy_cache = not isinstance(past_key_values, Cache)
82
+ if use_legacy_cache:
83
+ past_key_values = DynamicCache.from_legacy_cache(past_key_values)
84
+ past_key_values_length = past_key_values.get_usable_length(seq_length)
85
+
86
+ if position_ids is None:
87
+ device = input_ids.device if input_ids is not None else inputs_embeds.device
88
+ position_ids = torch.arange(
89
+ past_key_values_length, seq_length + past_key_values_length, dtype=torch.long, device=device
90
+ )
91
+ position_ids = position_ids.unsqueeze(0).view(-1, seq_length)
92
+ else:
93
+ position_ids = position_ids.view(-1, seq_length).long()
94
+
95
+ if inputs_embeds is None:
96
+ inputs_embeds = self.embed_tokens(input_ids)
97
+
98
+ if attention_mask is not None and self._attn_implementation == "flash_attention_2" and use_cache:
99
+ is_padding_right = attention_mask[:, -1].sum().item() != batch_size
100
+ if is_padding_right:
101
+ raise ValueError(
102
+ "You are attempting to perform batched generation with padding_side='right'"
103
+ " this may lead to unexpected behaviour for Flash Attention version of Mistral. Make sure to "
104
+ " call `tokenizer.padding_side = 'left'` before tokenizing the input. "
105
+ )
106
+
107
+ if self._attn_implementation == "flash_attention_2":
108
+ # 2d mask is passed through the layers
109
+ attention_mask = attention_mask if (attention_mask is not None and 0 in attention_mask) else None
110
+ elif self._attn_implementation == "sdpa" and not output_attentions:
111
+ # output_attentions=True can not be supported when using SDPA, and we fall back on
112
+ # the manual implementation that requires a 4D causal mask in all cases.
113
+ attention_mask = _prepare_4d_attention_mask_for_sdpa(
114
+ attention_mask, inputs_embeds.dtype
115
+ )
116
+ else:
117
+ # 4d mask is passed through the layers
118
+ attention_mask = _prepare_4d_attention_mask(
119
+ attention_mask, inputs_embeds.dtype,
120
+ )
121
+
122
+ hidden_states = inputs_embeds
123
+
124
+ # decoder layers
125
+ all_hidden_states = () if output_hidden_states else None
126
+ all_self_attns = () if output_attentions else None
127
+ next_decoder_cache = None
128
+
129
+ for decoder_layer in self.layers:
130
+ if output_hidden_states:
131
+ all_hidden_states += (hidden_states,)
132
+
133
+ if self.gradient_checkpointing and self.training:
134
+ layer_outputs = self._gradient_checkpointing_func(
135
+ decoder_layer.__call__,
136
+ hidden_states,
137
+ attention_mask,
138
+ position_ids,
139
+ past_key_values,
140
+ output_attentions,
141
+ use_cache,
142
+ )
143
+ else:
144
+ layer_outputs = decoder_layer(
145
+ hidden_states,
146
+ attention_mask=attention_mask,
147
+ position_ids=position_ids,
148
+ past_key_value=past_key_values,
149
+ output_attentions=output_attentions,
150
+ use_cache=use_cache,
151
+ )
152
+
153
+ hidden_states = layer_outputs[0]
154
+
155
+ if use_cache:
156
+ next_decoder_cache = layer_outputs[2 if output_attentions else 1]
157
+
158
+ if output_attentions:
159
+ all_self_attns += (layer_outputs[1],)
160
+
161
+ hidden_states = self.norm(hidden_states)
162
+
163
+ # add hidden states from the last decoder layer
164
+ if output_hidden_states:
165
+ all_hidden_states += (hidden_states,)
166
+
167
+ next_cache = None
168
+ if use_cache:
169
+ next_cache = next_decoder_cache.to_legacy_cache() if use_legacy_cache else next_decoder_cache
170
+
171
+ if not return_dict:
172
+ return tuple(v for v in [hidden_states, next_cache, all_hidden_states, all_self_attns] if v is not None)
173
+ return BaseModelOutputWithPast(
174
+ last_hidden_state=hidden_states,
175
+ past_key_values=next_cache,
176
+ hidden_states=all_hidden_states,
177
+ attentions=all_self_attns,
178
+ )
179
+
180
+ def _move_to_device(maybe_tensor, device: torch.device):
181
+ if torch.is_tensor(maybe_tensor):
182
+ return maybe_tensor.to(device, non_blocking=device.type == "cuda")
183
+ elif isinstance(maybe_tensor, dict):
184
+ return {key: _move_to_device(value, device) for key, value in maybe_tensor.items()}
185
+ elif isinstance(maybe_tensor, list):
186
+ return [_move_to_device(x, device) for x in maybe_tensor]
187
+ elif isinstance(maybe_tensor, tuple):
188
+ return tuple([_move_to_device(x, device) for x in maybe_tensor])
189
+ elif isinstance(maybe_tensor, Mapping):
190
+ return type(maybe_tensor)({k: _move_to_device(v, device) for k, v in maybe_tensor.items()})
191
+ else:
192
+ return maybe_tensor
193
+
194
+ def move_to_device(sample, device: torch.device):
195
+ if device.type == "cpu":
196
+ return sample
197
+
198
+ if len(sample) == 0:
199
+ return {}
200
+ return _move_to_device(sample, device)
201
+
202
+
203
+ def input_transform_func(
204
+ tokenizer: PreTrainedTokenizerFast,
205
+ examples: Dict[str, List],
206
+ always_add_eos: bool,
207
+ max_length: int,
208
+ instruction: str,
209
+ ) -> BatchEncoding:
210
+ if always_add_eos:
211
+ examples['input_texts'] = [instruction + input_example + tokenizer.eos_token for input_example in examples['input_texts']]
212
+ batch_dict = tokenizer(
213
+ examples['input_texts'],
214
+ max_length=max_length,
215
+ padding=True,
216
+ return_token_type_ids=False,
217
+ return_tensors="pt",
218
+ truncation=True)
219
+ return batch_dict
220
+
221
+
222
+ class PreNorm(torch.nn.Module):
223
+ def __init__(self, dim, fn, context_dim = None):
224
+ super().__init__()
225
+ self.fn = fn
226
+ self.norm = torch.nn.LayerNorm(dim)
227
+ self.norm_context = torch.nn.LayerNorm(context_dim) if exists(context_dim) else None
228
+
229
+ def forward(self, x, **kwargs):
230
+ x = self.norm(x)
231
+ if exists(self.norm_context):
232
+ context = kwargs['context']
233
+ normed_context = self.norm_context(context)
234
+ kwargs.update(context = normed_context)
235
+ return self.fn(x, **kwargs)
236
+
237
+ class GEGLU(torch.nn.Module):
238
+ def forward(self, x):
239
+ x, gates = x.chunk(2, dim = -1)
240
+ return x * torch.nn.functional.gelu(gates)
241
+
242
+ class FeedForward(torch.nn.Module):
243
+ def __init__(self, dim, mult = 4):
244
+ super().__init__()
245
+ self.net = torch.nn.Sequential(torch.nn.Linear(dim, dim * mult * 2),
246
+ GEGLU(),
247
+ torch.nn.Linear(dim * mult, dim))
248
+
249
+ def forward(self, x):
250
+ return self.net(x)
251
+
252
+ def exists(val):
253
+ return val is not None
254
+
255
+ def default(val, d):
256
+ return val if exists(val) else d
257
+
258
+
259
+ class Attention(torch.nn.Module):
260
+ def __init__(self, query_dim, context_dim = None, heads = 8, dim_head = 64):
261
+ super().__init__()
262
+ inner_dim = dim_head * heads
263
+ context_dim = default(context_dim, query_dim)
264
+ self.scale = dim_head ** -0.5
265
+ self.heads = heads
266
+
267
+ self.to_q = torch.nn.Linear(query_dim, inner_dim, bias = False)
268
+ self.to_kv = torch.nn.Linear(context_dim, inner_dim * 2, bias = False)
269
+ self.to_out = torch.nn.Linear(inner_dim, query_dim, bias = False)
270
+
271
+ def forward(self, x, context = None, mask = None):
272
+ h = self.heads
273
+ q = self.to_q(x)
274
+ context = default(context, x)
275
+ k, v = self.to_kv(context).chunk(2, dim = -1)
276
+ q, k, v = map(lambda t: rearrange(t, 'b n (h d) -> (b h) n d', h = h), (q, k, v))
277
+ with torch.backends.cuda.sdp_kernel(enable_flash=True, enable_mem_efficient=True):
278
+ out = torch.nn.functional.scaled_dot_product_attention(q, k, v)
279
+ out = rearrange(out, '(b h) n d -> b n (h d)', h = h)
280
+ return self.to_out(out)
281
+
282
+
283
+ class LatentAttentionModel(PreTrainedModel):
284
+ config_class = LatentAttentionConfig
285
+
286
+ def __init__(self, config: LatentAttentionConfig):
287
+ super().__init__(config)
288
+ ## cross-attention block
289
+ num_latents, latent_dim, cross_heads, cross_dim_head = config.num_latents_value, config.latent_dim, config.num_cross_heads, config.cross_dim_head
290
+ dim = config.hidden_dim
291
+ # init latent_attention and latents
292
+ self.cross_attend_blocks = torch.nn.ModuleList([
293
+ PreNorm(latent_dim, Attention(latent_dim, dim, heads = cross_heads, dim_head = cross_dim_head),
294
+ context_dim = dim),
295
+ PreNorm(latent_dim, FeedForward(latent_dim)),
296
+ ])
297
+ self.output_normalize = config.output_normalize
298
+ self.register_parameter("latents", torch.nn.Parameter(torch.randn(num_latents, latent_dim)))
299
+
300
+ def forward(self, hiddens, attention_mask: torch.Tensor=None):
301
+ ## cross-attention block
302
+ cross_attn, cross_ff = self.cross_attend_blocks
303
+ b, *_, device = *hiddens.shape, hiddens.device
304
+ x = repeat(self.latents, 'n d -> b n d', b = b)
305
+ hiddens = cross_attn(hiddens, context = x, mask = None) + hiddens
306
+ hiddens = cross_ff(hiddens) + hiddens
307
+ if attention_mask !=None:
308
+ s = torch.sum(hiddens * attention_mask.unsqueeze(-1).float(), dim=1)
309
+ d = attention_mask.sum(dim=1, keepdim=True).float()
310
+ hiddens = s / d
311
+ if self.output_normalize:
312
+ hiddens = torch.nn.functional.normalize(hiddens, p=2, dim=-1)
313
+ return hiddens
314
+
315
+ class NVEmbedModel(PreTrainedModel):
316
+ config_class = NVEmbedConfig
317
+ _no_split_modules = ["MistralDecoderLayer", "LatentAttentionModel"]
318
+
319
+ def __init__(self, config: NVEmbedConfig):
320
+ super().__init__(config)
321
+ self.latent_attention_model = AutoModel.from_config(config.latent_attention_config)
322
+ self.embedding_model = AutoModel.from_config(
323
+ config.text_config,
324
+ ) if config.text_config is not None else None
325
+ self.tokenizer = AutoTokenizer.from_pretrained(config.text_config._name_or_path) if config.text_config is not None else None
326
+ self.padding_side = config.padding_side
327
+ self.is_mask_instruction = config.is_mask_instruction
328
+ self.add_eos = config.add_eos
329
+ self.mask_type = config.mask_type
330
+ if config.add_pad_token and self.tokenizer is not None:
331
+ self.add_pad_token()
332
+
333
+ def add_pad_token(self):
334
+ self.tokenizer.pad_token = self.tokenizer.eos_token
335
+ self.tokenizer.padding_side = self.padding_side
336
+
337
+ def prepare_kwargs_from_batch(self, batch_dict: dict, instruction_lens: int, device: torch.device):
338
+ batch_dict = move_to_device(batch_dict, device)
339
+ attention_mask = batch_dict['attention_mask'].clone() if 'attention_mask' in batch_dict else None
340
+ if (attention_mask is not None and
341
+ self.padding_side == "right" and
342
+ self.is_mask_instruction == True and
343
+ instruction_lens > 0):
344
+ # Mask out the instruction tokens for mean-pooling
345
+ attention_mask[:, :instruction_lens] = 0
346
+ features: NVEmbedFeatures = {
347
+ 'input_ids': torch.tensor(batch_dict.get('input_ids').to(batch_dict.get('input_ids')).long()),
348
+ 'attention_mask': batch_dict['attention_mask'],
349
+ 'pool_mask': attention_mask,
350
+ }
351
+ return features
352
+
353
+ @torch.no_grad()
354
+ def _do_encode(self,
355
+ prompts: List[str],
356
+ batch_size: int=1,
357
+ instruction: str="",
358
+ max_length: int=4096,
359
+ num_workers: int=32,
360
+ **kwargs
361
+ ) -> Union[np.ndarray, torch.FloatTensor]:
362
+ dataset: Dataset = Dataset.from_dict({'input_texts': prompts})
363
+ dataset.set_transform(partial(input_transform_func,
364
+ self.tokenizer,
365
+ always_add_eos=True,
366
+ max_length=max_length,
367
+ instruction=instruction))
368
+
369
+ data_collator = DataCollatorWithPadding(self.tokenizer)
370
+ data_loader = DataLoader(
371
+ dataset,
372
+ batch_size=batch_size,
373
+ shuffle=False,
374
+ drop_last=False,
375
+ num_workers=num_workers,
376
+ collate_fn=data_collator,
377
+ pin_memory=True)
378
+
379
+ if self.padding_side == "right" and self.is_mask_instruction == True and len(instruction) > 0:
380
+ instruction_lens = len(self.tokenizer.tokenize(instruction))
381
+ else:
382
+ instruction_lens = 0
383
+
384
+ encoded_embeds = []
385
+ device = next(self.embedding_model.parameters()).device
386
+ for batch_dict in tqdm(data_loader, desc='encoding', mininterval=10):
387
+ features = self.prepare_kwargs_from_batch(batch_dict, instruction_lens, device=device)
388
+ embeds=self(**features)["sentence_embeddings"].squeeze(1)
389
+ encoded_embeds.append(embeds)
390
+ encoded_embeds = torch.cat(encoded_embeds, axis=0)
391
+ if "return_numpy" in kwargs and kwargs.get("return_numpy"):
392
+ encoded_embeds = encoded_embeds.cpu().detach().numpy()
393
+ return encoded_embeds
394
+
395
+ def forward(self, input_ids: torch.Tensor, attention_mask: torch.Tensor, pool_mask: Optional[torch.Tensor]=None, return_dict: bool=True):
396
+ autocast_ctx = torch.autocast if torch.cuda.is_available() else nullcontext
397
+ with autocast_ctx("cuda"):
398
+ ## decoder only layer
399
+ outputs = self.embedding_model(
400
+ input_ids=input_ids,
401
+ attention_mask=attention_mask,
402
+ )
403
+ ## latent attention layer
404
+ embeds = self.latent_attention_model(
405
+ outputs.last_hidden_state,
406
+ pool_mask,
407
+ )
408
+ if not return_dict:
409
+ return (embeds,)
410
+ return {"sentence_embeddings": embeds}
411
+
412
+
413
+ @torch.no_grad()
414
+ def encode(self, prompts: List[str], instruction: str="", max_length: int=4096, **kwargs):
415
+ if self.padding_side == "right" and self.is_mask_instruction == True and len(instruction) > 0:
416
+ instruction_lens = len(self.tokenizer.tokenize(instruction))
417
+ else:
418
+ instruction_lens = 0
419
+
420
+ device = next(self.embedding_model.parameters()).device
421
+ batch_dict = input_transform_func(self.tokenizer,
422
+ {"input_texts": [prompt for prompt in prompts]},
423
+ always_add_eos=True,
424
+ max_length=max_length,
425
+ instruction=instruction)
426
+
427
+ features: NVEmbedFeatures = self.prepare_kwargs_from_batch(batch_dict, instruction_lens, device=device)
428
+ return self(**features)["sentence_embeddings"].squeeze(1)
429
+
430
+
431
+ ## AutoModel Register
432
+ AutoModel.register(NVEmbedConfig, NVEmbedModel)
433
+ AutoModel.register(LatentAttentionConfig, LatentAttentionModel)
434
+ AutoModel.register(BidirectionalMistralConfig, BidirectionalMistralModel)
435
+
436
+ ## Register for auto class
437
+ NVEmbedModel.register_for_auto_class("AutoModel")
438
+ LatentAttentionModel.register_for_auto_class("AutoModel")
439
+ BidirectionalMistralModel.register_for_auto_class("AutoModel")
modules.json ADDED
@@ -0,0 +1,20 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ [
2
+ {
3
+ "idx": 0,
4
+ "name": "0",
5
+ "path": "",
6
+ "type": "sentence_transformers.models.Transformer"
7
+ },
8
+ {
9
+ "idx": 1,
10
+ "name": "1",
11
+ "path": "1_Pooling",
12
+ "type": "sentence_transformers.models.Pooling"
13
+ },
14
+ {
15
+ "idx": 2,
16
+ "name": "2",
17
+ "path": "2_Normalize",
18
+ "type": "sentence_transformers.models.Normalize"
19
+ }
20
+ ]
sentence_bert_config.json ADDED
@@ -0,0 +1,4 @@
 
 
 
 
 
1
+ {
2
+ "max_seq_length": 32768,
3
+ "do_lower_case": false
4
+ }
special_tokens_map.json ADDED
@@ -0,0 +1,30 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "bos_token": {
3
+ "content": "<s>",
4
+ "lstrip": false,
5
+ "normalized": false,
6
+ "rstrip": false,
7
+ "single_word": false
8
+ },
9
+ "eos_token": {
10
+ "content": "</s>",
11
+ "lstrip": false,
12
+ "normalized": false,
13
+ "rstrip": false,
14
+ "single_word": false
15
+ },
16
+ "pad_token": {
17
+ "content": "</s>",
18
+ "lstrip": false,
19
+ "normalized": false,
20
+ "rstrip": false,
21
+ "single_word": false
22
+ },
23
+ "unk_token": {
24
+ "content": "<unk>",
25
+ "lstrip": false,
26
+ "normalized": false,
27
+ "rstrip": false,
28
+ "single_word": false
29
+ }
30
+ }
tokenizer.json ADDED
The diff for this file is too large to render. See raw diff
 
tokenizer_config.json ADDED
@@ -0,0 +1,42 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "add_bos_token": true,
3
+ "add_eos_token": false,
4
+ "added_tokens_decoder": {
5
+ "0": {
6
+ "content": "<unk>",
7
+ "lstrip": false,
8
+ "normalized": false,
9
+ "rstrip": false,
10
+ "single_word": false,
11
+ "special": true
12
+ },
13
+ "1": {
14
+ "content": "<s>",
15
+ "lstrip": false,
16
+ "normalized": false,
17
+ "rstrip": false,
18
+ "single_word": false,
19
+ "special": true
20
+ },
21
+ "2": {
22
+ "content": "</s>",
23
+ "lstrip": false,
24
+ "normalized": false,
25
+ "rstrip": false,
26
+ "single_word": false,
27
+ "special": true
28
+ }
29
+ },
30
+ "additional_special_tokens": [],
31
+ "bos_token": "<s>",
32
+ "clean_up_tokenization_spaces": false,
33
+ "eos_token": "</s>",
34
+ "legacy": true,
35
+ "model_max_length": 1000000000000000019884624838656,
36
+ "pad_token": "</s>",
37
+ "sp_model_kwargs": {},
38
+ "spaces_between_special_tokens": false,
39
+ "tokenizer_class": "LlamaTokenizer",
40
+ "unk_token": "<unk>",
41
+ "use_default_system_prompt": false
42
+ }