File size: 6,112 Bytes
4b62d89 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 |
import torch
import torch.nn.functional as F
from sentence_transformers.models import Transformer as BaseTransformer
class EmbeddingModel(BaseTransformer):
    """Extract sentence embeddings from a (causal) language model.

    Wraps a Hugging Face model so it can serve as the transformer module of a
    ``SentenceTransformer``. ``forward`` returns the final-layer token states
    together with a pooled sentence embedding ("mean" or "last" token pooling).
    """

    def __init__(self, model_name_or_path=None, base_model=None, tokenizer=None,
                 max_seq_length=512, pooling="last", **kwargs):
        """Initialize the embedding model with a base model, tokenizer, and pooling strategy.

        Args:
            model_name_or_path: HuggingFace model name or path, forwarded to the
                sentence-transformers ``Transformer`` constructor. Ignored when
                ``base_model`` is given.
            base_model: Pre-initialized model (takes priority over
                ``model_name_or_path``).
            tokenizer: Tokenizer to use; required when ``base_model`` is given.
            max_seq_length: Maximum sequence length.
            pooling: ``"mean"`` selects masked mean pooling; any other value
                selects last-token pooling (see ``_get_embeddings``).
            **kwargs: Forwarded to the ``Transformer`` constructor when loading
                from ``model_name_or_path``.

        Raises:
            ValueError: If ``base_model`` is given without ``tokenizer``.
        """
        if base_model is not None:
            if tokenizer is None:
                raise ValueError("If base_model is provided, tokenizer must also be provided")
            # Skip Transformer.__init__ (which loads a model from model_name_or_path)
            # and run only the next initializer in the MRO; the attributes the
            # parent would normally set are assigned by hand below.
            # NOTE(review): any other attribute the installed sentence-transformers
            # version sets in Transformer.__init__ is not set here - verify.
            super(BaseTransformer, self).__init__()
            self.config = base_model.config
            self.max_seq_length = max_seq_length
            self.auto_model = base_model
            self._tokenizer = tokenizer
            self.tokenizer = tokenizer  # For compatibility
            self.do_lower_case = getattr(tokenizer, "do_lower_case", False)
            # Pad id with a fallback of 0 when the tokenizer defines none
            # (not read anywhere inside this class).
            self.padding_idx = tokenizer.pad_token_id if tokenizer.pad_token_id is not None else 0
            # Additional attributes normally provided by Transformer.
            self.name = 'Transformer'
            # "torch" is sentence-transformers' default backend identifier.
            self.backend = "torch"
        else:
            # Standard loading path of the sentence-transformers Transformer.
            super().__init__(model_name_or_path=model_name_or_path, max_seq_length=max_seq_length, **kwargs)
        self.pooling = pooling
        self.embedding_dim = self.auto_model.config.hidden_size
        # Drop the LM head to save memory. Replace it with an Identity instead of
        # deleting the attribute: a *ForCausalLM forward() still calls
        # self.lm_head(...), so a deleted head would make every forward pass raise
        # AttributeError. The head's output is never used by this class.
        if hasattr(self.auto_model, "lm_head"):
            self.auto_model.lm_head = torch.nn.Identity()

    def forward(self, features):
        """Run the model and pool its final hidden states into sentence embeddings.

        Args:
            features: Dictionary with at least ``'input_ids'`` and
                ``'attention_mask'`` tensors.

        Returns:
            Dictionary with ``'token_embeddings'`` (final-layer hidden states) and
            ``'sentence_embedding'`` (one pooled vector per sequence).
        """
        input_ids = features['input_ids']
        attention_mask = features['attention_mask']
        outputs = self.auto_model(
            input_ids=input_ids,
            attention_mask=attention_mask,
            output_hidden_states=True,
            return_dict=True,
        )
        # Final layer's states; taken from hidden_states so this works whether or
        # not the output class exposes last_hidden_state.
        hidden_states = outputs.hidden_states[-1]
        tokenizer = self._tokenizer if hasattr(self, '_tokenizer') else self.tokenizer
        embeddings = self._get_embeddings(
            hidden_states,
            input_ids,
            tokenizer.eos_token_id,
            attention_mask=attention_mask,
        )
        # Keys expected by SentenceTransformer.
        return {'token_embeddings': hidden_states, 'sentence_embedding': embeddings}

    def _get_embeddings(self, hidden_states, input_ids, eos_token_id, pooling=None, attention_mask=None):
        """Pool per-token hidden states into one embedding per sequence.

        Args:
            hidden_states: Final-layer states, shape (batch, seq_len, hidden).
            input_ids: Token ids, shape (batch, seq_len).
            eos_token_id: EOS id used by last-token pooling; may be None, in
                which case the last real token is always used.
            pooling: ``"mean"`` or any other value for last-token pooling;
                defaults to ``self.pooling``.
            attention_mask: Optional (batch, seq_len) mask, nonzero for real
                tokens. When omitted it is derived as ``input_ids != pad_token_id``,
                or treated as all ones if the tokenizer has no pad token.

        Returns:
            Tensor of shape (batch, hidden).

        Pooling:
            * ``"mean"``: average of the hidden states over real tokens (float32).
            * otherwise: hidden state at the first real EOS token of each
              sequence; sequences without one use their last real token, and
              all-padding sequences use position 0 (default float dtype).
        """
        if pooling is None:
            pooling = self.pooling
        tokenizer = self._tokenizer if hasattr(self, '_tokenizer') else self.tokenizer

        if attention_mask is None:
            pad_token_id = tokenizer.pad_token_id
            if pad_token_id is None:
                attention_mask = torch.ones_like(input_ids)
            else:
                attention_mask = input_ids != pad_token_id
        valid = attention_mask.bool()

        if pooling == "mean":
            weights = valid.float().unsqueeze(-1)
            summed = torch.sum(hidden_states * weights, dim=1)
            counts = torch.clamp(torch.sum(weights, dim=1), min=1e-9)
            return summed / counts

        batch_size, seq_len = input_ids.shape
        positions = torch.arange(seq_len, device=input_ids.device).expand(batch_size, seq_len)
        # Index of the last real token per row; rows that are all padding map to 0.
        last_valid = torch.where(valid, positions, torch.full_like(positions, -1)).amax(dim=1).clamp(min=0)
        if eos_token_id is None:
            target = last_valid
        else:
            # Only EOS tokens at real positions count. When pad_token == eos_token,
            # padding would otherwise be mistaken for the end of the sequence
            # (with left padding, position 0 would be picked).
            is_eos = (input_ids == eos_token_id) & valid
            first_eos = torch.where(is_eos, positions, torch.full_like(positions, seq_len)).amin(dim=1)
            target = torch.where(is_eos.any(dim=1), first_eos, last_valid)

        rows = torch.arange(batch_size, device=hidden_states.device)
        pooled = hidden_states[rows, target.to(hidden_states.device)]
        # Match the historical output dtype (a torch.zeros buffer in the default dtype).
        return pooled.to(torch.get_default_dtype())

    def get_sentence_embedding_dimension(self):
        """Return the dimension of the sentence embeddings (the model's hidden size)."""
        return self.embedding_dim

    def get_word_embedding_dimension(self):
        """Return the dimension of the word/token embeddings (the model's hidden size)."""
        return self.embedding_dim