Files changed (5) hide show
  1. README.md +8 -32
  2. config.json +1 -1
  3. modeling_drama.py +152 -111
  4. modeling_drama_nested.py +0 -639
  5. modeling_drama_non_nested.py +0 -184
README.md CHANGED
@@ -60,10 +60,9 @@ model_name = "facebook/drama-base"
60
  device = "cuda" if torch.cuda.is_available() else "cpu"
61
  tokenizer = AutoTokenizer.from_pretrained(model_name)
62
  model = AutoModel.from_pretrained(model_name, trust_remote_code=True).to(device)
63
- use_nested = False
64
- query_embs = model.encode_queries(tokenizer, queries, use_nested=use_nested)
65
- doc_embs = model.encode_documents(tokenizer, documents, use_nested=use_nested)
66
 
 
 
67
 
68
  scores = query_embs @ doc_embs.T
69
  print(scores.tolist())
@@ -78,8 +77,8 @@ print(scores.tolist())
78
  DRAMA models are trained using Matryoshka Representation Learning ([MRL](https://github.com/RAIVNLab/MRL)) to support flexible dimensionality. Both queries and documents can be encoded into smaller dimensions, such as 256, using the following:
79
 
80
  ```python
81
- query_embs = model.encode_queries(tokenizer, queries, dim=256, use_nested=use_nested)
82
- doc_embs = model.encode_documents(tokenizer, documents, dim=256, use_nested=use_nested)
83
 
84
  scores = query_embs @ doc_embs.T
85
  print(scores.tolist())
@@ -102,8 +101,8 @@ documents = [
102
 
103
  model = SentenceTransformer("facebook/drama-base", trust_remote_code=True)
104
 
105
- query_embs = model.encode(queries, prompt_name="query", use_nested=use_nested)
106
- doc_embs = model.encode(documents, use_nested=use_nested)
107
 
108
  scores = model.similarity(query_embs, doc_embs)
109
  print(scores.tolist())
@@ -129,8 +128,8 @@ documents = [
129
 
130
  model = SentenceTransformer("facebook/drama-base", truncate_dim=256, trust_remote_code=True)
131
 
132
- query_embs = model.encode(queries, prompt_name="query", use_nested=use_nested)
133
- doc_embs = model.encode(documents, use_nested=use_nested)
134
 
135
  scores = model.similarity(query_embs, doc_embs)
136
  print(scores.tolist())
@@ -166,26 +165,3 @@ If you find our paper or models helpful, please consider cite as follows:
166
  year={2025}
167
  }
168
  ```
169
-
170
- ## Efficient DRAMA
171
- ### Nested Tensors
172
- [Nested Tensors](https://docs.pytorch.org/docs/stable/nested.html) provide a way to handle ragged-shaped data within a single tensor, allowing for efficient operations on such data.
173
- They store data in a compact packed representation while offering a standard PyTorch tensor interface, making it easy to apply various
174
- operations.
175
- Nested Tensors are particularly advantageous for model deployments that perform inference on large batches of sequences with varying
176
- lengths. Traditional tensors require padding all sequences in a batch to the same length, which can be inefficient, especially when
177
- the batch includesmany short sequences and a single long sequence. Nested Tensors eliminate the need for padding, thus avoiding
178
- unnecessary computation on extra pad tokens. This results in more efficient processing of batches with varying sequence lengths.
179
-
180
- ### Performance
181
- Experiments have demonstrated a 1.7x to 2.3x (base,large and 1B) improvement in queries per second (QPS) for batch inference with sequences of varied lengths.
182
-
183
- ### Usage
184
- To enable Nested Tensors, simply set the use_nested variable to true. This will activate the nested jagged tensors and allow you to
185
- take advantage of efficient inference.
186
-
187
- > Prerequisites Package versions as this code have been tested with these versions. Please use these or some latest versions to avoid compatibility issues.
188
-
189
- >- Python: 3.12
190
- >- Transformers: 4.51.1
191
- >- PyTorch: 2.7.1
 
60
  device = "cuda" if torch.cuda.is_available() else "cpu"
61
  tokenizer = AutoTokenizer.from_pretrained(model_name)
62
  model = AutoModel.from_pretrained(model_name, trust_remote_code=True).to(device)
 
 
 
63
 
64
+ query_embs = model.encode_queries(tokenizer, queries)
65
+ doc_embs = model.encode_documents(tokenizer, documents)
66
 
67
  scores = query_embs @ doc_embs.T
68
  print(scores.tolist())
 
77
  DRAMA models are trained using Matryoshka Representation Learning ([MRL](https://github.com/RAIVNLab/MRL)) to support flexible dimensionality. Both queries and documents can be encoded into smaller dimensions, such as 256, using the following:
78
 
79
  ```python
80
+ query_embs = model.encode_queries(tokenizer, queries, dim=256)
81
+ doc_embs = model.encode_documents(tokenizer, documents, dim=256)
82
 
83
  scores = query_embs @ doc_embs.T
84
  print(scores.tolist())
 
101
 
102
  model = SentenceTransformer("facebook/drama-base", trust_remote_code=True)
103
 
104
+ query_embs = model.encode(queries, prompt_name="query")
105
+ doc_embs = model.encode(documents)
106
 
107
  scores = model.similarity(query_embs, doc_embs)
108
  print(scores.tolist())
 
128
 
129
  model = SentenceTransformer("facebook/drama-base", truncate_dim=256, trust_remote_code=True)
130
 
131
+ query_embs = model.encode(queries, prompt_name="query")
132
+ doc_embs = model.encode(documents)
133
 
134
  scores = model.similarity(query_embs, doc_embs)
135
  print(scores.tolist())
 
165
  year={2025}
166
  }
167
  ```
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
config.json CHANGED
@@ -4,7 +4,7 @@
4
  "DramaModel"
5
  ],
6
  "auto_map": {
7
- "AutoModel": "modeling_drama.DramaModelWrapper"
8
  },
9
  "attention_bias": false,
10
  "attention_dropout": 0.0,
 
4
  "DramaModel"
5
  ],
6
  "auto_map": {
7
+ "AutoModel": "modeling_drama.DramaModel"
8
  },
9
  "attention_bias": false,
10
  "attention_dropout": 0.0,
modeling_drama.py CHANGED
@@ -1,125 +1,166 @@
1
- import sys
2
- import warnings
3
 
 
 
4
 
5
- def _check_torch_version():
6
- """Check if PyTorch version is >= 2.7.1"""
7
- try:
8
- import torch
9
 
10
- # Simple version comparison
11
- version_str = torch.__version__.split("+")[0] # Remove any suffixes like +cu118
12
- version_parts = version_str.split(".")
13
 
14
- # Compare major version
15
- if int(version_parts[0]) > 2:
16
- return True
17
- # Compare minor version
18
- elif int(version_parts[0]) == 2 and int(version_parts[1]) > 7:
19
- return True
20
- # Compare patch version
21
- elif (
22
- int(version_parts[0]) == 2
23
- and int(version_parts[1]) == 7
24
- and int(version_parts[2]) >= 1
25
- ):
26
- return True
27
-
28
- return False
29
- except (ImportError, AttributeError, IndexError, ValueError):
30
- return False
31
-
32
-
33
- def _check_transformers_version():
34
- """Check if Transformers version is >= 4.51.1"""
35
- try:
36
- import transformers
37
-
38
- # Simple version comparison
39
- version_str = transformers.__version__.split("+")[0] # Remove any suffixes
40
- version_parts = version_str.split(".")
41
-
42
- # Compare major version
43
- if int(version_parts[0]) > 4:
44
- return True
45
- # Compare minor version
46
- elif int(version_parts[0]) == 4 and int(version_parts[1]) > 51:
47
- return True
48
- # Compare patch version
49
- elif (
50
- int(version_parts[0]) == 4
51
- and int(version_parts[1]) == 51
52
- and int(version_parts[2]) >= 1
53
- ):
54
- return True
55
-
56
- return False
57
- except (ImportError, AttributeError, IndexError, ValueError):
58
- return False
59
-
60
-
61
- class DramaModelWrapper:
62
  """
63
- Factory class for DramaModel that returns the appropriate implementation
64
- based on the Python version.
65
-
66
- If Python version >= 3.12, returns an instance of the nested tensor implementation.
67
- Otherwise, returns an instance of the non-nested implementation.
68
  """
69
 
70
- @classmethod
71
- def from_pretrained(cls, pretrained_model_name_or_path, *model_args, **kwargs):
72
  """
73
- Instantiate a pretrained model from a pre-trained model configuration.
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
74
 
75
- This method is required by the transformers library's auto model loading mechanism.
 
 
 
 
 
 
 
 
 
76
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
77
  Args:
78
- pretrained_model_name_or_path: Path to the pretrained model or its name
79
- *model_args: Additional positional arguments to pass to the implementation
80
- **kwargs: Additional keyword arguments to pass to the implementation
81
-
82
  Returns:
83
- An instance of the appropriate DramaModel implementation.
84
  """
85
- # Check Python version
86
- use_nested = sys.version_info >= (3, 15)
87
- if not use_nested:
88
- warnings.warn(
89
- "Python version < 3.12 detected. Using non-nested implementation."
90
- )
91
- # For Python versions below 3.12, use the non-nested implementation
92
- from .modeling_drama_non_nested import DramaModel as NonNestedDramaModel
93
-
94
- return NonNestedDramaModel.from_pretrained(
95
- pretrained_model_name_or_path, *model_args, **kwargs
96
- )
97
-
98
- # Check PyTorch version
99
- if not _check_torch_version():
100
- warnings.warn(
101
- "PyTorch version < 2.7.1 detected. Falling back to non-nested implementation."
102
- )
103
- from .modeling_drama_non_nested import DramaModel as NonNestedDramaModel
104
-
105
- return NonNestedDramaModel.from_pretrained(
106
- pretrained_model_name_or_path, *model_args, **kwargs
107
- )
108
-
109
- # Check Transformers version
110
- if not _check_transformers_version():
111
- warnings.warn(
112
- "Transformers version < 4.51.1 detected. Falling back to non-nested implementation."
113
- )
114
- from .modeling_drama_non_nested import DramaModel as NonNestedDramaModel
115
-
116
- return NonNestedDramaModel.from_pretrained(
117
- pretrained_model_name_or_path, *model_args, **kwargs
118
- )
119
-
120
- # Use the nested tensor implementation if all requirements are met
121
- from .modeling_drama_nested import DramaModel as NestedDramaModel
122
-
123
- return NestedDramaModel.from_pretrained(
124
- pretrained_model_name_or_path, *model_args, **kwargs
125
  )
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from __future__ import annotations
 
2
 
3
+ import torch
4
+ import torch.nn.functional as F
5
 
6
+ from transformers import LlamaModel, LlamaConfig, PreTrainedTokenizer
7
+ from transformers.modeling_attn_mask_utils import AttentionMaskConverter
 
 
8
 
 
 
 
9
 
10
+ class DramaModel(LlamaModel):
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
11
  """
12
+ DramaModel is a modified version of the LlamaModel that supports bi-directional attention
13
+ and provides query and document encoding functionalities.
 
 
 
14
  """
15
 
16
+ def __init__(self, config: LlamaConfig):
 
17
  """
18
+ Initializes the DramaModel by disabling causal masking in self-attention layers.
19
+ """
20
+ super().__init__(config)
21
+ for layer in self.layers:
22
+ layer.self_attn.is_causal = False
23
+ # query prefix
24
+ self.query_prefix = "Query: "
25
+ self.max_seq_len = 8192
26
+ self.hidden_size = config.hidden_size
27
+
28
+ def _update_causal_mask(
29
+ self,
30
+ attention_mask: torch.Tensor,
31
+ input_tensor: torch.Tensor,
32
+ cache_position: torch.Tensor,
33
+ past_seen_tokens=None,
34
+ output_attentions=False,
35
+ ):
36
+ """
37
+ Updates the causal mask for attention computations.
38
+ """
39
+ if self.config._attn_implementation == "flash_attention_2":
40
+ if attention_mask is not None and (attention_mask == 0.0).any():
41
+ return attention_mask
42
+ return None
43
+ if attention_mask is None or attention_mask.dim() == 4:
44
+ return attention_mask
45
+
46
+ return AttentionMaskConverter._expand_mask(
47
+ mask=attention_mask,
48
+ dtype=input_tensor.dtype,
49
+ )
50
 
51
+ def _average_pool(
52
+ self, last_hidden_states: torch.Tensor, attention_mask: torch.Tensor
53
+ ) -> torch.Tensor:
54
+ """
55
+ Computes the average pooled representation of the last hidden states.
56
+ """
57
+ last_hidden = last_hidden_states.masked_fill(
58
+ ~attention_mask[..., None].bool(), 0.0
59
+ )
60
+ return last_hidden.sum(dim=1) / attention_mask.sum(dim=1)[..., None]
61
 
62
+ def _tokenize(
63
+ self,
64
+ tokenizer: PreTrainedTokenizer,
65
+ texts: list[str],
66
+ max_seq_len: int = None,
67
+ ):
68
+ """
69
+ Tokenizes input text sequences with optional sequence length restriction.
70
+ """
71
+ if max_seq_len is None:
72
+ max_seq_len = self.max_seq_len
73
+ tokenized = tokenizer(
74
+ texts,
75
+ padding=True,
76
+ truncation=True,
77
+ max_length=max_seq_len,
78
+ return_tensors='pt',
79
+ ).to(self.device)
80
+ return tokenized
81
+
82
+ def encode(self, input_ids, attention_mask, dim, *args, **kwargs):
83
+ """
84
+ Pass through the model and compute normalized embeddings.
85
+
86
  Args:
87
+ input_ids (torch.Tensor): Input token IDs.
88
+ attention_mask (torch.Tensor): Attention mask tensor.
89
+ dim (int): Dimensionality for output embeddings.
90
+
91
  Returns:
92
+ torch.Tensor: Normalized output embeddings.
93
  """
94
+ outputs = self.forward(
95
+ input_ids, attention_mask, *args, **kwargs
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
96
  )
97
+ embeddings = self._average_pool(
98
+ outputs.last_hidden_state[:, :, :dim], attention_mask
99
+ )
100
+ # normalize embeddings
101
+ embeddings = F.normalize(embeddings, p=2, dim=1)
102
+ return embeddings
103
+
104
+ def encode_queries(
105
+ self,
106
+ tokenizer: PreTrainedTokenizer,
107
+ queries: list[str],
108
+ max_seq_len: int = None,
109
+ dim: int = None,
110
+ ):
111
+ """
112
+ Encodes a list of queries into embeddings.
113
+
114
+ Args:
115
+ tokenizer (PreTrainedTokenizer): Tokenizer for text processing.
116
+ queries (list[str]): List of query texts.
117
+ max_seq_len (int, optional): Maximum sequence length.
118
+ dim (int, optional): Dimensionality for output embeddings.
119
+
120
+ Returns:
121
+ torch.Tensor: Encoded query embeddings in shape (num_queries, dim).
122
+ """
123
+ if not queries:
124
+ raise ValueError("queries must not be empty.")
125
+ if not isinstance(queries, list) or not all(isinstance(q, str) for q in queries):
126
+ raise ValueError("queries must be a list of strings.")
127
+ if tokenizer is None:
128
+ raise ValueError("tokenizer must not be None.")
129
+ if dim is not None and (dim < 1 or dim > self.hidden_size):
130
+ raise ValueError(f"dim must be in range [1, {self.hidden_size}].")
131
+ queries = [self.query_prefix + query for query in queries]
132
+ tokenized_queries = self._tokenize(tokenizer, queries, max_seq_len)
133
+ embeddings = self.encode(**tokenized_queries, dim=dim)
134
+ return embeddings
135
+
136
+ def encode_documents(
137
+ self,
138
+ tokenizer: PreTrainedTokenizer,
139
+ documents: list[str],
140
+ max_seq_len: int = None,
141
+ dim: int = None,
142
+ ):
143
+ """
144
+ Encodes a list of documents into embeddings.
145
+
146
+ Args:
147
+ tokenizer (PreTrainedTokenizer): Tokenizer for text processing.
148
+ documents (list[str]): List of document texts.
149
+ max_seq_len (int, optional): Maximum sequence length.
150
+ dim (int, optional): Dimensionality for output embeddings.
151
+
152
+ Returns:
153
+ torch.Tensor: Encoded document embeddings in shape (num_documents, dim).
154
+ """
155
+ if not documents:
156
+ raise ValueError("documents must not be empty.")
157
+ if not isinstance(documents, list) or not all(isinstance(d, str) for d in documents):
158
+ raise ValueError("documents must be a list of strings.")
159
+ if tokenizer is None:
160
+ raise ValueError("tokenizer must not be None.")
161
+ if dim is not None and (dim < 1 or dim > self.hidden_size):
162
+ raise ValueError(f"dim must be in range [1, {self.hidden_size}].")
163
+ tokenized_documents = self._tokenize(tokenizer, documents, max_seq_len)
164
+ embeddings = self.encode(**tokenized_documents, dim=dim)
165
+ return embeddings
166
+
modeling_drama_nested.py DELETED
@@ -1,639 +0,0 @@
1
- from __future__ import annotations
2
-
3
- import torch
4
- import torch.nn as nn
5
- import torch.nn.functional as F
6
- from torch.nested._internal.nested_tensor import nested_from_padded
7
-
8
- from transformers import (
9
- LlamaConfig,
10
- LlamaModel,
11
- LlamaPreTrainedModel,
12
- PreTrainedTokenizer,
13
- )
14
- from transformers.cache_utils import Cache, DynamicCache
15
- from transformers.modeling_attn_mask_utils import AttentionMaskConverter
16
- from transformers.modeling_flash_attention_utils import FlashAttentionKwargs
17
- from transformers.modeling_outputs import BaseModelOutputWithPast
18
- from transformers.models.llama.modeling_llama import (
19
- LlamaAttention,
20
- LlamaDecoderLayer,
21
- LlamaMLP,
22
- LlamaRMSNorm,
23
- LlamaRotaryEmbedding,
24
- rotate_half,
25
- )
26
- from transformers.processing_utils import Unpack
27
-
28
-
29
- class ModifiedLlamaAttention(LlamaAttention):
30
- def __init__(self, *args: Any, **kwargs: Any) -> None:
31
- super().__init__(*args, **kwargs)
32
- self.is_causal = False
33
-
34
- def forward(
35
- self,
36
- hidden_states: torch.Tensor,
37
- position_embeddings: Tuple[torch.Tensor, torch.Tensor],
38
- attention_mask: Optional[torch.Tensor],
39
- past_key_value: Optional[Cache] = None,
40
- cache_position: Optional[torch.LongTensor] = None,
41
- **kwargs: Unpack[FlashAttentionKwargs],
42
- ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
43
- input_shape = hidden_states.shape[:-1]
44
- hidden_shape = (*input_shape, -1, self.head_dim)
45
-
46
- query_states = self.q_proj(hidden_states).view(hidden_shape).transpose(1, 2)
47
- key_states = self.k_proj(hidden_states).view(hidden_shape).transpose(1, 2)
48
- value_states = self.v_proj(hidden_states).view(hidden_shape).transpose(1, 2)
49
-
50
- cos, sin = position_embeddings
51
- query_states, key_states = apply_rotary_pos_emb(
52
- query_states, key_states, cos, sin
53
- )
54
-
55
- if past_key_value is not None:
56
- # sin and cos are specific to RoPE models; cache_position needed for the static cache
57
- cache_kwargs = {"sin": sin, "cos": cos, "cache_position": cache_position}
58
- key_states, value_states = past_key_value.update(
59
- key_states, value_states, self.layer_idx, cache_kwargs
60
- )
61
-
62
- if self.config._attn_implementation != "eager":
63
- if self.config._attn_implementation == "sdpa" and kwargs.get(
64
- "output_attentions", False
65
- ):
66
- warnings.warn(
67
- "`torch.nn.functional.scaled_dot_product_attention` does not support `output_attentions=True`. Falling back to "
68
- 'eager attention. This warning can be removed using the argument `attn_implementation="eager"` when loading the model.'
69
- )
70
-
71
- attn_output, attn_weights = sdpa_attention_forward(
72
- self,
73
- query_states,
74
- key_states,
75
- value_states,
76
- attention_mask,
77
- dropout=0.0,
78
- scaling=self.scaling,
79
- is_causal=False,
80
- **kwargs,
81
- )
82
-
83
- attn_output = attn_output.reshape(*input_shape, -1).contiguous()
84
- attn_output = self.o_proj(attn_output)
85
- return attn_output, attn_weights
86
-
87
-
88
- def sdpa_attention_forward(
89
- module: torch.nn.Module,
90
- query: torch.Tensor,
91
- key: torch.Tensor,
92
- value: torch.Tensor,
93
- attention_mask: torch.Tensor,
94
- dropout: float = 0.0,
95
- scaling: Optional[float] = None,
96
- is_causal: Optional[bool] = None,
97
- **kwargs: Any,
98
- ) -> Tuple[torch.Tensor, None]:
99
- if hasattr(module, "num_key_value_groups"):
100
- if key.is_nested:
101
- key = repeat_jagged_kv(key, module.num_key_value_groups)
102
- value = repeat_jagged_kv(value, module.num_key_value_groups)
103
- else:
104
- key = repeat_dense_kv(key, module.num_key_value_groups)
105
- value = repeat_dense_kv(value, module.num_key_value_groups)
106
-
107
- causal_mask = attention_mask
108
- if attention_mask is not None and causal_mask.ndim == 4:
109
- causal_mask = causal_mask[:, :, :, : key.shape[-2]]
110
-
111
- # SDPA with memory-efficient backend is bugged with non-contiguous inputs and custom attn_mask for some torch versions
112
- # Reference: https://github.com/pytorch/pytorch/issues/112577.
113
- query = query.contiguous()
114
- key = key.contiguous()
115
- value = value.contiguous()
116
-
117
- # We dispatch to SDPA's Flash Attention or Efficient kernels via this `is_causal` if statement instead of an inline conditional assignment
118
- # in SDPA to support both torch.compile's dynamic shapes and full graph options. An inline conditional prevents dynamic shapes from compiling.
119
- # Note that it is important to check first for the shape, otherwise compile will fail with `argument 'is_causal' must be bool, not SymBool`
120
- if is_causal is None:
121
- is_causal = query.shape[2] > 1 and causal_mask is None
122
-
123
- # Shapes (e.g. query.shape[2]) are tensors during jit tracing, resulting in `is_causal` being a tensor.
124
- # We convert it to a bool for the SDPA kernel that only accepts bools.
125
- if torch.jit.is_tracing() and isinstance(is_causal, torch.Tensor):
126
- is_causal = is_causal.item()
127
-
128
- attn_output = torch.nn.functional.scaled_dot_product_attention(
129
- query,
130
- key,
131
- value,
132
- attn_mask=causal_mask,
133
- dropout_p=dropout,
134
- scale=scaling,
135
- is_causal=is_causal,
136
- )
137
- attn_output = attn_output.transpose(1, 2).contiguous()
138
-
139
- return attn_output, None
140
-
141
-
142
- def repeat_jagged_kv(hidden_states: torch.Tensor, n_rep: int) -> torch.Tensor:
143
- """
144
- This is the equivalent of torch.repeat_interleave(x, dim=1, repeats=n_rep). The hidden states go from (batch,
145
- num_key_value_heads, seqlen, head_dim) to (batch, num_attention_heads, seqlen, head_dim)
146
- """
147
- batch, num_key_value_heads, slen, head_dim = hidden_states.shape
148
- expand_shape = (batch, num_key_value_heads, -1, n_rep, head_dim)
149
- if n_rep == 1:
150
- return hidden_states
151
- hidden_states = (
152
- hidden_states.unsqueeze(3)
153
- .expand(expand_shape)
154
- .transpose(1, 2)
155
- .flatten(2, 3)
156
- .transpose(1, 2)
157
- )
158
- return hidden_states
159
-
160
-
161
- def repeat_dense_kv(hidden_states: torch.Tensor, n_rep: int) -> torch.Tensor:
162
- """
163
- This is the equivalent of torch.repeat_interleave(x, dim=1, repeats=n_rep). The hidden states go from (batch,
164
- num_key_value_heads, seqlen, head_dim) to (batch, num_attention_heads, seqlen, head_dim)
165
- """
166
- batch, num_key_value_heads, slen, head_dim = hidden_states.shape
167
- if n_rep == 1:
168
- return hidden_states
169
- hidden_states = hidden_states[:, :, None, :, :].expand(
170
- batch, num_key_value_heads, n_rep, slen, head_dim
171
- )
172
- return hidden_states.reshape(batch, num_key_value_heads * n_rep, slen, head_dim)
173
-
174
-
175
- def apply_rotary_pos_emb(
176
- q: torch.Tensor,
177
- k: torch.Tensor,
178
- cos: torch.Tensor,
179
- sin: torch.Tensor,
180
- unsqueeze_dim: int = 1,
181
- ) -> Tuple[torch.Tensor, torch.Tensor]:
182
- """Applies Rotary Position Embedding to the query and key tensors.
183
-
184
- Args:
185
- q (`torch.Tensor`): The query tensor.
186
- k (`torch.Tensor`): The key tensor.
187
- cos (`torch.Tensor`): The cosine part of the rotary embedding.
188
- sin (`torch.Tensor`): The sine part of the rotary embedding.
189
- position_ids (`torch.Tensor`, *optional*):
190
- Deprecated and unused.
191
- unsqueeze_dim (`int`, *optional*, defaults to 1):
192
- The 'unsqueeze_dim' argument specifies the dimension along which to unsqueeze cos[position_ids] and
193
- sin[position_ids] so that they can be properly broadcasted to the dimensions of q and k. For example, note
194
- that cos[position_ids] and sin[position_ids] have the shape [batch_size, seq_len, head_dim]. Then, if q and
195
- k have the shape [batch_size, heads, seq_len, head_dim], then setting unsqueeze_dim=1 makes
196
- cos[position_ids] and sin[position_ids] broadcastable to the shapes of q and k. Similarly, if q and k have
197
- the shape [batch_size, seq_len, heads, head_dim], then set unsqueeze_dim=2.
198
- Returns:
199
- `tuple(torch.Tensor)` comprising of the query and key tensors rotated using the Rotary Position Embedding.
200
- """
201
- cos = cos.unsqueeze(unsqueeze_dim)
202
- sin = sin.unsqueeze(unsqueeze_dim)
203
- if q.is_nested and k.is_nested:
204
- if q.layout != torch.jagged:
205
- raise NotImplementedError(f"Unsupported layout: {q.layout}")
206
- if k.layout != torch.jagged:
207
- raise NotImplementedError(f"Unsupported layout: {k.layout}")
208
- return _jagged_tensor_forward(q, k, cos, sin)
209
- else:
210
- return _padded_tensor_forward(q, k, cos, sin)
211
-
212
-
213
- def _jagged_tensor_forward(
214
- q: torch.Tensor,
215
- k: torch.Tensor,
216
- cos: torch.Tensor,
217
- sin: torch.Tensor,
218
- ) -> Tuple[torch.Tensor, torch.Tensor]:
219
- q_dense = q.to_padded_tensor(0.0)
220
- k_dense = k.to_padded_tensor(0.0)
221
- q_dense_embed = (q_dense * cos) + (rotate_half(q_dense) * sin)
222
- k_dense_embed = (k_dense * cos) + (rotate_half(k_dense) * sin)
223
- q_jagged_embed = convert_dense_to_jagged(q, q_dense_embed)
224
- k_jagged_embed = convert_dense_to_jagged(k, k_dense_embed)
225
- return q_jagged_embed, k_jagged_embed
226
-
227
-
228
- def _padded_tensor_forward(
229
- q: torch.Tensor,
230
- k: torch.Tensor,
231
- cos: torch.Tensor,
232
- sin: torch.Tensor,
233
- ) -> Tuple[torch.Tensor, torch.Tensor]:
234
- q_embed = (q * cos) + (rotate_half(q) * sin)
235
- k_embed = (k * cos) + (rotate_half(k) * sin)
236
- return q_embed, k_embed
237
-
238
-
239
- def convert_dense_to_jagged(nested_q: torch.Tensor, q: torch.Tensor) -> torch.Tensor:
240
- padded_max_S = nested_q._get_max_seqlen()
241
- total_L = nested_q._values.shape[nested_q._ragged_idx - 1]
242
- if padded_max_S is None:
243
- # use upper bound on max seqlen if it's not present
244
- padded_max_S = total_L
245
-
246
- # convert dense tensor -> jagged
247
- q = q.expand(
248
- [
249
- x if i != nested_q._ragged_idx else padded_max_S
250
- for i, x in enumerate(q.shape)
251
- ]
252
- )
253
- nested_result = nested_from_padded(
254
- q,
255
- offsets=nested_q._offsets,
256
- ragged_idx=nested_q._ragged_idx,
257
- sum_S=total_L,
258
- min_seqlen=nested_q._get_min_seqlen(),
259
- max_seqlen=padded_max_S,
260
- )
261
- return nested_result
262
-
263
-
264
- class ModifiedLlamaDecoderLayer(LlamaDecoderLayer):
265
- def __init__(self, config: LlamaConfig, layer_idx: int) -> None:
266
- nn.Module.__init__(self)
267
- self.hidden_size: int = config.hidden_size
268
-
269
- self.self_attn = ModifiedLlamaAttention(config=config, layer_idx=layer_idx)
270
-
271
- self.mlp = LlamaMLP(config)
272
- self.input_layernorm = LlamaRMSNorm(config.hidden_size, eps=config.rms_norm_eps)
273
- self.post_attention_layernorm = LlamaRMSNorm(
274
- config.hidden_size, eps=config.rms_norm_eps
275
- )
276
-
277
-
278
- class LlamaBiModel(LlamaModel):
279
- def __init__(self, config: LlamaConfig) -> None:
280
- LlamaPreTrainedModel.__init__(self, config)
281
- self.padding_idx: int = config.pad_token_id
282
- self.vocab_size: int = config.vocab_size
283
-
284
- self.embed_tokens = nn.Embedding(
285
- config.vocab_size, config.hidden_size, self.padding_idx
286
- )
287
- self.layers = nn.ModuleList(
288
- [
289
- ModifiedLlamaDecoderLayer(config, layer_idx)
290
- for layer_idx in range(config.num_hidden_layers)
291
- ]
292
- )
293
- self.norm = LlamaRMSNorm(config.hidden_size, eps=config.rms_norm_eps)
294
- self.rotary_emb = LlamaRotaryEmbedding(config=config)
295
- self.gradient_checkpointing = False
296
-
297
- # Initialize weights and apply final processing
298
- self.post_init()
299
-
300
- def _update_causal_mask(
301
- self,
302
- attention_mask: torch.Tensor,
303
- input_tensor: torch.Tensor,
304
- cache_position: torch.Tensor,
305
- past_seen_tokens=None,
306
- output_attentions=False,
307
- ):
308
- """
309
- Updates the causal mask for attention computations.
310
- """
311
- if self.config._attn_implementation == "flash_attention_2":
312
- if attention_mask is not None and (attention_mask == 0.0).any():
313
- return attention_mask
314
- return None
315
- if attention_mask is None or attention_mask.dim() == 4:
316
- return attention_mask
317
-
318
- return AttentionMaskConverter._expand_mask(
319
- mask=attention_mask,
320
- dtype=input_tensor.dtype,
321
- )
322
-
323
- def forward(
324
- self,
325
- input_ids: Optional[torch.LongTensor] = None,
326
- attention_mask: Optional[torch.Tensor] = None,
327
- position_ids: Optional[torch.LongTensor] = None,
328
- past_key_values: Optional[Union[Cache, List[torch.FloatTensor]]] = None,
329
- inputs_embeds: Optional[torch.FloatTensor] = None,
330
- use_cache: Optional[bool] = None,
331
- output_attentions: Optional[bool] = None,
332
- output_hidden_states: Optional[bool] = None,
333
- return_dict: Optional[bool] = None,
334
- cache_position: Optional[torch.LongTensor] = None,
335
- ) -> Union[Tuple[torch.Tensor], BaseModelOutputWithPast]:
336
- output_attentions = (
337
- output_attentions
338
- if output_attentions is not None
339
- else self.config.output_attentions
340
- )
341
- output_hidden_states = (
342
- output_hidden_states
343
- if output_hidden_states is not None
344
- else self.config.output_hidden_states
345
- )
346
- # use_cache = use_cache if use_cache is not None else self.config.use_cache
347
- use_cache = False
348
- return_dict = (
349
- return_dict if return_dict is not None else self.config.use_return_dict
350
- )
351
-
352
- if (input_ids is None) ^ (inputs_embeds is not None):
353
- raise ValueError(
354
- "You cannot specify both input_ids and inputs_embeds at the same time, and must specify either one"
355
- )
356
- if self.gradient_checkpointing and self.training and use_cache:
357
- warnings.warn(
358
- "`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`.",
359
- DeprecationWarning,
360
- stacklevel=2,
361
- )
362
- use_cache = False
363
-
364
- if inputs_embeds is None:
365
- inputs_embeds = self.embed_tokens(input_ids)
366
-
367
- return_legacy_cache = False
368
- if (
369
- use_cache and not isinstance(past_key_values, Cache) and not self.training
370
- ): # kept for BC (non `Cache` `past_key_values` inputs)
371
- return_legacy_cache = True
372
- past_key_values = DynamicCache.from_legacy_cache(past_key_values)
373
- warnings.warn(
374
- "We detected that you are passing `past_key_values` as a tuple and this is deprecated and will be removed in v4.43. "
375
- "Please use an appropriate `Cache` class (https://huggingface.co/docs/transformers/v4.41.3/en/internal/generation_utils#transformers.Cache)",
376
- DeprecationWarning,
377
- stacklevel=2,
378
- )
379
-
380
- if cache_position is None:
381
- past_seen_tokens = (
382
- past_key_values.get_seq_length() if past_key_values is not None else 0
383
- )
384
- if inputs_embeds.is_nested:
385
- seq_len = inputs_embeds._get_max_seqlen()
386
- else:
387
- seq_len = inputs_embeds.shape[1]
388
- cache_position = torch.arange(
389
- past_seen_tokens,
390
- past_seen_tokens + seq_len,
391
- device=inputs_embeds.device,
392
- )
393
- if position_ids is None:
394
- position_ids = cache_position.unsqueeze(0)
395
- if not inputs_embeds.is_nested:
396
- causal_mask = self._update_causal_mask(
397
- attention_mask,
398
- inputs_embeds,
399
- cache_position,
400
- past_key_values,
401
- )
402
-
403
- else:
404
- causal_mask = None
405
- hidden_states = inputs_embeds
406
-
407
- # create position embeddings to be shared across the decoder layers
408
- position_embeddings = self.rotary_emb(hidden_states, position_ids)
409
-
410
- # decoder layers
411
- all_hidden_states = () if output_hidden_states else None
412
- all_self_attns = () if output_attentions else None
413
- next_decoder_cache = None
414
-
415
- for decoder_layer in self.layers:
416
- if output_hidden_states:
417
- all_hidden_states += (hidden_states,)
418
-
419
- if self.gradient_checkpointing and self.training:
420
- layer_outputs = self._gradient_checkpointing_func(
421
- decoder_layer.__call__,
422
- hidden_states,
423
- causal_mask,
424
- position_ids,
425
- past_key_values,
426
- output_attentions,
427
- use_cache,
428
- cache_position,
429
- position_embeddings,
430
- )
431
- else:
432
- layer_outputs = decoder_layer(
433
- hidden_states,
434
- attention_mask=causal_mask,
435
- position_ids=position_ids,
436
- past_key_value=past_key_values,
437
- output_attentions=output_attentions,
438
- use_cache=use_cache,
439
- cache_position=cache_position,
440
- position_embeddings=position_embeddings,
441
- )
442
-
443
- hidden_states = layer_outputs[0]
444
-
445
- if use_cache:
446
- next_decoder_cache = layer_outputs[2 if output_attentions else 1]
447
-
448
- if output_attentions:
449
- all_self_attns += (layer_outputs[1],)
450
-
451
- hidden_states = self.norm(hidden_states)
452
-
453
- # add hidden states from the last decoder layer
454
- if output_hidden_states:
455
- all_hidden_states += (hidden_states,)
456
-
457
- next_cache = next_decoder_cache if use_cache else None
458
- if return_legacy_cache:
459
- next_cache = next_cache.to_legacy_cache()
460
-
461
- if not return_dict:
462
- return tuple(
463
- v
464
- for v in [hidden_states, next_cache, all_hidden_states, all_self_attns]
465
- if v is not None
466
- )
467
- return BaseModelOutputWithPast(
468
- last_hidden_state=hidden_states,
469
- past_key_values=next_cache,
470
- hidden_states=all_hidden_states,
471
- attentions=all_self_attns,
472
- )
473
-
474
-
475
- class DramaModel(LlamaBiModel):
476
- """
477
- DramaModel is a modified version of the LlamaModel that supports bi-directional attention
478
- and provides query and document encoding functionalities.
479
- """
480
-
481
- def __init__(self, config: LlamaConfig):
482
- """
483
- Initializes the DramaModel by disabling causal masking in self-attention layers.
484
- """
485
- super().__init__(config)
486
- for layer in self.layers:
487
- layer.self_attn.is_causal = False
488
- # query prefix
489
- self.query_prefix = "Query: "
490
- self.max_seq_len = 8192
491
- self.hidden_size = config.hidden_size
492
-
493
- def _average_pool(
494
- self, last_hidden_states: torch.Tensor, attention_mask: torch.Tensor
495
- ) -> torch.Tensor:
496
- """
497
- Computes the average pooled representation of the last hidden states.
498
- """
499
- last_hidden = last_hidden_states.masked_fill(
500
- ~attention_mask[..., None].bool(), 0.0
501
- )
502
- return last_hidden.sum(dim=1) / attention_mask.sum(dim=1)[..., None]
503
-
504
- def _tokenize(
505
- self,
506
- tokenizer: PreTrainedTokenizer,
507
- texts: list[str],
508
- max_seq_len: int = None,
509
- use_nested: bool = False,
510
- ):
511
- """
512
- Tokenizes input text sequences with optional sequence length restriction.
513
- """
514
- if max_seq_len is None:
515
- max_seq_len = self.max_seq_len
516
- if use_nested:
517
- tokenized = tokenizer(
518
- texts,
519
- truncation=True,
520
- max_length=max_seq_len,
521
- return_length=True,
522
- )
523
- tokenized.input_ids = torch.nested.nested_tensor(
524
- tokenized.input_ids, layout=torch.jagged
525
- ).to(self.device)
526
- tokenized.attention_mask = None
527
- else:
528
- tokenized = tokenizer(
529
- texts,
530
- padding=True,
531
- truncation=True,
532
- max_length=max_seq_len,
533
- return_tensors="pt",
534
- ).to(self.device)
535
- tokenizer_ouput = {}
536
- tokenizer_ouput["input_ids"] = tokenized.input_ids
537
- tokenizer_ouput["attention_mask"] = tokenized.attention_mask
538
- return tokenizer_ouput
539
-
540
- def encode(self, input_ids, attention_mask, dim, *args, **kwargs):
541
- """
542
- Pass through the model and compute normalized embeddings.
543
-
544
- Args:
545
- input_ids (torch.Tensor): Input token IDs.
546
- attention_mask (torch.Tensor): Attention mask tensor.
547
- dim (int): Dimensionality for output embeddings.
548
-
549
- Returns:
550
- torch.Tensor: Normalized output embeddings.
551
- """
552
-
553
- outputs = self.forward(
554
- input_ids, attention_mask, *args, **kwargs
555
- ).last_hidden_state
556
- if not outputs.is_nested:
557
- if dim is not None:
558
- outputs = outputs[:, :, :dim]
559
- embeddings = self._average_pool(outputs, attention_mask)
560
- else:
561
- if dim is not None:
562
- outputs, _ = outputs.split_with_sizes(
563
- split_sizes=[dim, outputs.shape[-1] - dim], dim=-1
564
- )
565
- embeddings = outputs.sum(dim=-2)
566
- # normalize embeddings
567
- embeddings = F.normalize(embeddings, p=2, dim=1)
568
- return embeddings
569
-
570
- def encode_queries(
571
- self,
572
- tokenizer: PreTrainedTokenizer,
573
- queries: list[str],
574
- max_seq_len: int = None,
575
- dim: int = None,
576
- use_nested: bool = False,
577
- ):
578
- """
579
- Encodes a list of queries into embeddings.
580
-
581
- Args:
582
- tokenizer (PreTrainedTokenizer): Tokenizer for text processing.
583
- queries (list[str]): List of query texts.
584
- max_seq_len (int, optional): Maximum sequence length.
585
- dim (int, optional): Dimensionality for output embeddings.
586
-
587
- Returns:
588
- torch.Tensor: Encoded query embeddings in shape (num_queries, dim).
589
- """
590
- if not queries:
591
- raise ValueError("queries must not be empty.")
592
- if not isinstance(queries, list) or not all(
593
- isinstance(q, str) for q in queries
594
- ):
595
- raise ValueError("queries must be a list of strings.")
596
- if tokenizer is None:
597
- raise ValueError("tokenizer must not be None.")
598
- if dim is not None and (dim < 1 or dim > self.hidden_size):
599
- raise ValueError(f"dim must be in range [1, {self.hidden_size}].")
600
- queries = [self.query_prefix + query for query in queries]
601
- tokenized_queries = self._tokenize(tokenizer, queries, max_seq_len, use_nested)
602
- embeddings = self.encode(**tokenized_queries, dim=dim)
603
- return embeddings
604
-
605
- def encode_documents(
606
- self,
607
- tokenizer: PreTrainedTokenizer,
608
- documents: list[str],
609
- max_seq_len: int = None,
610
- dim: int = None,
611
- use_nested: bool = False,
612
- ):
613
- """
614
- Encodes a list of documents into embeddings.
615
-
616
- Args:
617
- tokenizer (PreTrainedTokenizer): Tokenizer for text processing.
618
- documents (list[str]): List of document texts.
619
- max_seq_len (int, optional): Maximum sequence length.
620
- dim (int, optional): Dimensionality for output embeddings.
621
-
622
- Returns:
623
- torch.Tensor: Encoded document embeddings in shape (num_documents, dim).
624
- """
625
- if not documents:
626
- raise ValueError("documents must not be empty.")
627
- if not isinstance(documents, list) or not all(
628
- isinstance(d, str) for d in documents
629
- ):
630
- raise ValueError("documents must be a list of strings.")
631
- if tokenizer is None:
632
- raise ValueError("tokenizer must not be None.")
633
- if dim is not None and (dim < 1 or dim > self.hidden_size):
634
- raise ValueError(f"dim must be in range [1, {self.hidden_size}].")
635
- tokenized_documents = self._tokenize(
636
- tokenizer, documents, max_seq_len, use_nested
637
- )
638
- embeddings = self.encode(**tokenized_documents, dim=dim)
639
- return embeddings
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
modeling_drama_non_nested.py DELETED
@@ -1,184 +0,0 @@
1
- from __future__ import annotations
2
-
3
- import warnings
4
-
5
- import torch
6
- import torch.nn.functional as F
7
-
8
- from transformers import LlamaConfig, LlamaModel, PreTrainedTokenizer
9
- from transformers.modeling_attn_mask_utils import AttentionMaskConverter
10
-
11
-
12
- class DramaModel(LlamaModel):
13
- """
14
- DramaModel is a modified version of the LlamaModel that supports bi-directional attention
15
- and provides query and document encoding functionalities.
16
- """
17
-
18
- def __init__(self, config: LlamaConfig):
19
- """
20
- Initializes the DramaModel by disabling causal masking in self-attention layers.
21
- """
22
- super().__init__(config)
23
- for layer in self.layers:
24
- layer.self_attn.is_causal = False
25
- # query prefix
26
- self.query_prefix = "Query: "
27
- self.max_seq_len = 8192
28
- self.hidden_size = config.hidden_size
29
-
30
- def _update_causal_mask(
31
- self,
32
- attention_mask: torch.Tensor,
33
- input_tensor: torch.Tensor,
34
- cache_position: torch.Tensor,
35
- past_seen_tokens=None,
36
- output_attentions=False,
37
- ):
38
- """
39
- Updates the causal mask for attention computations.
40
- """
41
- if self.config._attn_implementation == "flash_attention_2":
42
- if attention_mask is not None and (attention_mask == 0.0).any():
43
- return attention_mask
44
- return None
45
- if attention_mask is None or attention_mask.dim() == 4:
46
- return attention_mask
47
-
48
- return AttentionMaskConverter._expand_mask(
49
- mask=attention_mask,
50
- dtype=input_tensor.dtype,
51
- )
52
-
53
- def _average_pool(
54
- self, last_hidden_states: torch.Tensor, attention_mask: torch.Tensor
55
- ) -> torch.Tensor:
56
- """
57
- Computes the average pooled representation of the last hidden states.
58
- """
59
- last_hidden = last_hidden_states.masked_fill(
60
- ~attention_mask[..., None].bool(), 0.0
61
- )
62
- return last_hidden.sum(dim=1) / attention_mask.sum(dim=1)[..., None]
63
-
64
- def _tokenize(
65
- self,
66
- tokenizer: PreTrainedTokenizer,
67
- texts: list[str],
68
- max_seq_len: int = None,
69
- use_nested: bool = False, # Added for API compatibility with nested version
70
- ):
71
- """
72
- Tokenizes input text sequences with optional sequence length restriction.
73
- """
74
- if max_seq_len is None:
75
- max_seq_len = self.max_seq_len
76
- tokenized = tokenizer(
77
- texts,
78
- padding=True,
79
- truncation=True,
80
- max_length=max_seq_len,
81
- return_tensors="pt",
82
- ).to(self.device)
83
- return tokenized
84
-
85
- def encode(self, input_ids, attention_mask, dim, *args, **kwargs):
86
- """
87
- Pass through the model and compute normalized embeddings.
88
-
89
- Args:
90
- input_ids (torch.Tensor): Input token IDs.
91
- attention_mask (torch.Tensor): Attention mask tensor.
92
- dim (int): Dimensionality for output embeddings.
93
-
94
- Returns:
95
- torch.Tensor: Normalized output embeddings.
96
- """
97
- outputs = self.forward(input_ids, attention_mask, *args, **kwargs)
98
- embeddings = self._average_pool(
99
- outputs.last_hidden_state[:, :, :dim], attention_mask
100
- )
101
- # normalize embeddings
102
- embeddings = F.normalize(embeddings, p=2, dim=1)
103
- return embeddings
104
-
105
- def encode_queries(
106
- self,
107
- tokenizer: PreTrainedTokenizer,
108
- queries: list[str],
109
- max_seq_len: int = None,
110
- dim: int = None,
111
- use_nested: bool = False, # Added for API compatibility with nested version
112
- ):
113
- """
114
- Encodes a list of queries into embeddings.
115
-
116
- Args:
117
- tokenizer (PreTrainedTokenizer): Tokenizer for text processing.
118
- queries (list[str]): List of query texts.
119
- max_seq_len (int, optional): Maximum sequence length.
120
- dim (int, optional): Dimensionality for output embeddings.
121
-
122
- Returns:
123
- torch.Tensor: Encoded query embeddings in shape (num_queries, dim).
124
- """
125
- if not queries:
126
- raise ValueError("queries must not be empty.")
127
- if not isinstance(queries, list) or not all(
128
- isinstance(q, str) for q in queries
129
- ):
130
- raise ValueError("queries must be a list of strings.")
131
- if tokenizer is None:
132
- raise ValueError("tokenizer must not be None.")
133
- if dim is not None and (dim < 1 or dim > self.hidden_size):
134
- raise ValueError(f"dim must be in range [1, {self.hidden_size}].")
135
- if use_nested:
136
- warnings.warn(
137
- "use_nested is not supported due to package import versions.",
138
- UserWarning,
139
- )
140
- queries = [self.query_prefix + query for query in queries]
141
- tokenized_queries = self._tokenize(tokenizer, queries, max_seq_len, use_nested)
142
- embeddings = self.encode(**tokenized_queries, dim=dim)
143
- return embeddings
144
-
145
- def encode_documents(
146
- self,
147
- tokenizer: PreTrainedTokenizer,
148
- documents: list[str],
149
- max_seq_len: int = None,
150
- dim: int = None,
151
- use_nested: bool = False, # Added for API compatibility with nested version
152
- ):
153
- """
154
- Encodes a list of documents into embeddings.
155
-
156
- Args:
157
- tokenizer (PreTrainedTokenizer): Tokenizer for text processing.
158
- documents (list[str]): List of document texts.
159
- max_seq_len (int, optional): Maximum sequence length.
160
- dim (int, optional): Dimensionality for output embeddings.
161
-
162
- Returns:
163
- torch.Tensor: Encoded document embeddings in shape (num_documents, dim).
164
- """
165
- if not documents:
166
- raise ValueError("documents must not be empty.")
167
- if not isinstance(documents, list) or not all(
168
- isinstance(d, str) for d in documents
169
- ):
170
- raise ValueError("documents must be a list of strings.")
171
- if tokenizer is None:
172
- raise ValueError("tokenizer must not be None.")
173
- if dim is not None and (dim < 1 or dim > self.hidden_size):
174
- raise ValueError(f"dim must be in range [1, {self.hidden_size}].")
175
- if use_nested:
176
- warnings.warn(
177
- "use_nested is not supported due to package import versions.",
178
- UserWarning,
179
- )
180
- tokenized_documents = self._tokenize(
181
- tokenizer, documents, max_seq_len, use_nested
182
- )
183
- embeddings = self.encode(**tokenized_documents, dim=dim)
184
- return embeddings