aletlvl committed on
Commit
f495fe3
·
verified ·
1 Parent(s): 211103a

Modeling with embeddings function

Browse files
Files changed (1) hide show
  1. modeling_nicheformer.py +66 -0
modeling_nicheformer.py CHANGED
@@ -100,6 +100,54 @@ class NicheformerModel(NicheformerPreTrainedModel):
100
  )
101
 
102
  return transformer_output
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
103
 
104
  class NicheformerForMaskedLM(NicheformerPreTrainedModel):
105
  def __init__(self, config: NicheformerConfig):
@@ -160,6 +208,24 @@ class NicheformerForMaskedLM(NicheformerPreTrainedModel):
160
  hidden_states=transformer_output,
161
  )
162
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
163
 
164
  def complete_masking(batch, masking_p, n_tokens):
165
  """Apply masking to input batch for masked language modeling.
 
100
  )
101
 
102
  return transformer_output
103
+
104
def get_embeddings(self, input_ids, attention_mask=None, layer: int = -1, with_context: bool = False) -> torch.Tensor:
    """Extract pooled sequence embeddings from an intermediate transformer layer.

    Args:
        input_ids: Input token IDs, shape (batch, seq_len).
        attention_mask: Optional mask with 1 at real tokens and 0 at padding.
            Padded positions are excluded both from attention and from the
            final mean pooling.
        layer: Which transformer layer to extract embeddings from
            (-1 means last layer).
        with_context: Whether to include the 3 leading context tokens
            in the pooled embeddings.

    Returns:
        torch.Tensor: Mean-pooled embeddings, shape (batch, hidden_dim).
    """
    # Token embeddings plus positional information.
    token_embedding = self.embeddings(input_ids)

    if self.config.learnable_pe:
        pos_embedding = self.positional_embedding(self.pos.to(token_embedding.device))
        embeddings = self.dropout(token_embedding + pos_embedding)
    else:
        embeddings = self.positional_embedding(token_embedding)

    # Normalize negative indices so -1 addresses the last layer.
    if layer < 0:
        layer = self.config.nlayers + layer

    # nn.TransformerEncoderLayer expects True at *padded* positions in
    # src_key_padding_mask, hence the inversion.
    if attention_mask is not None:
        padding_mask = ~attention_mask.bool()
    else:
        padding_mask = None

    # Run only the layers up to (and including) the requested one.
    for i in range(layer + 1):
        embeddings = self.encoder.layers[i](
            embeddings,
            src_key_padding_mask=padding_mask,
            is_causal=False,
        )

    # Drop the 3 leading context tokens unless explicitly requested.
    # NOTE(review): assumes the first 3 positions are always context tokens
    # — confirm against the tokenization scheme.
    if not with_context:
        embeddings = embeddings[:, 3:, :]

    # Mean-pool over the sequence dimension.
    # BUGFIX: padded positions were previously averaged in; mask them out
    # when an attention_mask is given. Identical to a plain mean when the
    # mask is absent or all ones.
    if attention_mask is None:
        return embeddings.mean(dim=1)

    keep = attention_mask.bool()
    if not with_context:
        keep = keep[:, 3:]
    keep = keep.unsqueeze(-1).to(embeddings.dtype)
    # clamp avoids division by zero for fully-padded (degenerate) rows.
    denom = keep.sum(dim=1).clamp(min=1.0)
    return (embeddings * keep).sum(dim=1) / denom
151
 
152
  class NicheformerForMaskedLM(NicheformerPreTrainedModel):
153
  def __init__(self, config: NicheformerConfig):
 
208
  hidden_states=transformer_output,
209
  )
210
 
211
def get_embeddings(self, input_ids, attention_mask=None, layer: int = -1, with_context: bool = False) -> torch.Tensor:
    """Return pooled embeddings by delegating to the base Nicheformer model.

    Args:
        input_ids: Input token IDs.
        attention_mask: Attention mask.
        layer: Which transformer layer to extract embeddings from
            (-1 means last layer).
        with_context: Whether to include context tokens in the embeddings.

    Returns:
        torch.Tensor: Embeddings tensor produced by the base model.
    """
    # All embedding logic lives on the wrapped base model; forward every
    # argument through unchanged.
    base_model = self.nicheformer
    return base_model.get_embeddings(
        input_ids=input_ids,
        attention_mask=attention_mask,
        layer=layer,
        with_context=with_context,
    )
229
 
230
  def complete_masking(batch, masking_p, n_tokens):
231
  """Apply masking to input batch for masked language modeling.