Commit: add an embeddings-extraction function (`get_embeddings`) to the modeling code
Browse files — modeling_nicheformer.py: +66 lines, −0 lines
modeling_nicheformer.py
CHANGED
|
@@ -100,6 +100,54 @@ class NicheformerModel(NicheformerPreTrainedModel):
|
|
| 100 |
)
|
| 101 |
|
| 102 |
return transformer_output
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 103 |
|
| 104 |
class NicheformerForMaskedLM(NicheformerPreTrainedModel):
|
| 105 |
def __init__(self, config: NicheformerConfig):
|
|
@@ -160,6 +208,24 @@ class NicheformerForMaskedLM(NicheformerPreTrainedModel):
|
|
| 160 |
hidden_states=transformer_output,
|
| 161 |
)
|
| 162 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 163 |
|
| 164 |
def complete_masking(batch, masking_p, n_tokens):
|
| 165 |
"""Apply masking to input batch for masked language modeling.
|
|
|
|
| 100 |
)
|
| 101 |
|
| 102 |
return transformer_output
|
| 103 |
+
|
| 104 |
+
def get_embeddings(self, input_ids, attention_mask=None, layer: int = -1, with_context: bool = False) -> torch.Tensor:
    """Extract mean-pooled embeddings from an intermediate (or final) encoder layer.

    Args:
        input_ids: Input token IDs.
        attention_mask: Optional attention mask (1 = attend, 0 = padding).
        layer: Which transformer layer to read embeddings from; negative
            values index from the end (-1 means the last layer).
        with_context: Whether to keep the leading context tokens (the
            first 3 positions) in the pooled embeddings.

    Returns:
        torch.Tensor: Mean-pooled embedding per input sequence.
    """
    # Combine token embeddings with positional information.
    tokens = self.embeddings(input_ids)

    if self.config.learnable_pe:
        positions = self.positional_embedding(self.pos.to(tokens.device))
        hidden = self.dropout(tokens + positions)
    else:
        hidden = self.positional_embedding(tokens)

    # Resolve a negative layer index against the layer count (-1 -> last layer).
    target_layer = layer if layer >= 0 else self.config.nlayers + layer

    # The encoder expects an inverted boolean mask as src_key_padding_mask
    # (True marks padding positions to be ignored).
    key_padding = None if attention_mask is None else ~attention_mask.bool()

    # Run the encoder one layer at a time, stopping at the requested depth.
    for block in self.encoder.layers[: target_layer + 1]:
        hidden = block(
            hidden,
            src_key_padding_mask=key_padding,
            is_causal=False,
        )

    # Drop the 3 leading context tokens unless the caller asked to keep them.
    if not with_context:
        hidden = hidden[:, 3:, :]

    # Mean pooling over the sequence dimension yields one vector per sequence.
    return hidden.mean(dim=1)
|
| 151 |
|
| 152 |
class NicheformerForMaskedLM(NicheformerPreTrainedModel):
|
| 153 |
def __init__(self, config: NicheformerConfig):
|
|
|
|
| 208 |
hidden_states=transformer_output,
|
| 209 |
)
|
| 210 |
|
| 211 |
+
def get_embeddings(self, input_ids, attention_mask=None, layer: int = -1, with_context: bool = False) -> torch.Tensor:
    """Delegate embedding extraction to the wrapped base model.

    Args:
        input_ids: Input token IDs.
        attention_mask: Optional attention mask.
        layer: Which transformer layer to read embeddings from (-1 means
            the last layer).
        with_context: Whether to keep the leading context tokens.

    Returns:
        torch.Tensor: Embeddings produced by the base model.
    """
    # Thin pass-through: the base model implements the actual extraction.
    kwargs = dict(
        input_ids=input_ids,
        attention_mask=attention_mask,
        layer=layer,
        with_context=with_context,
    )
    return self.nicheformer.get_embeddings(**kwargs)
|
| 229 |
|
| 230 |
def complete_masking(batch, masking_p, n_tokens):
|
| 231 |
"""Apply masking to input batch for masked language modeling.
|