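"""SentenceTransformer-compatible wrapper that extracts sentence embeddings
from a causal language model using mean or last-token pooling."""
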
import torch
from sentence_transformers.models import Transformer as BaseTransformer

class EmbeddingModel(BaseTransformer):
    """Wrapper model for extracting embeddings from a causal language model using SentenceTransformer framework."""
    
    def __init__(self, model_name_or_path=None, base_model=None, tokenizer=None, max_seq_length=512, pooling="last", **kwargs):
        """
        Initialize the embedding model with a base model, tokenizer, and pooling strategy.
        
        Args:
            model_name_or_path: HuggingFace model name or path (used by SentenceTransformer)
            base_model: Pre-initialized model (prioritized over model_name_or_path)
            tokenizer: Tokenizer to use (must be provided if base_model is provided)
            max_seq_length: Maximum sequence length
            pooling: Pooling strategy: "mean" (average over non-padding tokens) or
                "last" (hidden state at the first EOS token, falling back to the
                last non-padding token when no EOS is present)
        """
        # If we're given a base_model directly, use that instead of loading from model_name_or_path
        if base_model is not None:
            if tokenizer is None:
                raise ValueError("If base_model is provided, tokenizer must also be provided")
            
            # Skip BaseTransformer's __init__ (which would load a model from
            # model_name_or_path) and only initialize the underlying nn.Module;
            # the attributes BaseTransformer expects are set up manually below.
            super(BaseTransformer, self).__init__()
            
            # Set up the model configuration manually
            self.config = base_model.config
            self.max_seq_length = max_seq_length
            self.auto_model = base_model
            self._tokenizer = tokenizer
            self.tokenizer = tokenizer  # For compatibility 
            self.do_lower_case = getattr(tokenizer, "do_lower_case", False)
            
            # For certain models (like Llama), ensure that padding_idx is set correctly
            self.padding_idx = tokenizer.pad_token_id if tokenizer.pad_token_id is not None else 0
            
            # Additional attributes from BaseTransformer
            self.name = 'Transformer'
            self.backend = "huggingface_transformers"  # Default backend
        else:
            # Use standard initialization from BaseTransformer
            super().__init__(model_name_or_path=model_name_or_path, max_seq_length=max_seq_length, **kwargs)
        
        self.pooling = pooling
        self.embedding_dim = self.auto_model.config.hidden_size
        
        # Replace the LM head (if present) with an identity so the vocabulary
        # projection is skipped and its weights can be freed. Deleting the head
        # outright would break the causal LM's forward pass, which still calls it.
        if hasattr(self.auto_model, "lm_head"):
            self.auto_model.lm_head = torch.nn.Identity()

    def forward(self, features):
        """
        Forward pass through the model to get embeddings.
        Adapted to work with SentenceTransformer's expected format.
        
        Args:
            features: Dictionary with 'input_ids', 'attention_mask', etc.
            
        Returns:
            Dictionary with embeddings
        """
        input_ids = features['input_ids']
        attention_mask = features['attention_mask']
        
        # Get the model outputs
        outputs = self.auto_model(
            input_ids=input_ids,
            attention_mask=attention_mask,
            output_hidden_states=True,
            return_dict=True
        )
        
        # Take the final layer's hidden states and pool them into sentence embeddings
        hidden_states = outputs.hidden_states[-1]
        embeddings = self._get_embeddings(
            hidden_states,
            input_ids,
            self._tokenizer.eos_token_id if hasattr(self, '_tokenizer') else self.tokenizer.eos_token_id,
        )
        
        # Return in the format expected by SentenceTransformer
        return {'token_embeddings': hidden_states, 'sentence_embedding': embeddings}

    def _get_embeddings(self, hidden_states, input_ids, eos_token_id, pooling=None):
        """Extract sentence embeddings from hidden states using the specified pooling strategy."""
        if pooling is None:
            pooling = self.pooling

        batch_size = input_ids.shape[0]
        hidden_dim = hidden_states.size(-1)
        embeddings = torch.zeros(
            batch_size, hidden_dim, dtype=hidden_states.dtype, device=hidden_states.device
        )
        
        tokenizer = self._tokenizer if hasattr(self, '_tokenizer') else self.tokenizer

        if pooling == "mean":
            # Mean pooling: average the hidden states over non-padding positions
            attention_mask = (input_ids != tokenizer.pad_token_id).float()
            sum_embeddings = torch.sum(
                hidden_states * attention_mask.unsqueeze(-1), dim=1
            )
            input_mask_sum = torch.sum(attention_mask, dim=1).unsqueeze(-1)
            input_mask_sum = torch.clamp(input_mask_sum, min=1e-9)
            embeddings = sum_embeddings / input_mask_sum
        else:
            # Last-token pooling: use the hidden state at the first EOS token;
            # sequences without an EOS fall back to their last non-padding token.
            eos_positions = (input_ids == eos_token_id).nonzero(as_tuple=True)
            batch_indices = eos_positions[0]
            token_positions = eos_positions[1]
            has_eos = torch.zeros(
                batch_size, dtype=torch.bool, device=hidden_states.device
            )
            has_eos[batch_indices] = True
            unique_batch_indices = batch_indices.unique()
            for i in unique_batch_indices:
                idx = (batch_indices == i).nonzero(as_tuple=True)[0][0]
                embeddings[i] = hidden_states[i, token_positions[idx]]

            non_eos_indices = (~has_eos).nonzero(as_tuple=True)[0]
            if len(non_eos_indices) > 0:
                for i in non_eos_indices:
                    mask = (input_ids[i] != tokenizer.pad_token_id).nonzero(
                        as_tuple=True
                    )[0]
                    embeddings[i] = hidden_states[i, mask[-1]]

        return embeddings
        
    def get_sentence_embedding_dimension(self):
        """Return the dimension of the sentence embeddings."""
        return self.embedding_dim
        
    def get_word_embedding_dimension(self):
        """Return the dimension of the word/token embeddings."""
        return self.embedding_dim
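

if __name__ == "__main__":
    # Minimal usage sketch (illustrative only, not part of this module's API):
    # wrap a small causal LM in EmbeddingModel and encode a few sentences via
    # the SentenceTransformer pipeline. The choice of "gpt2" and the
    # single-module SentenceTransformer wiring are assumptions made for this
    # demonstration; details may vary across sentence-transformers versions.
    from sentence_transformers import SentenceTransformer
    from transformers import AutoModelForCausalLM, AutoTokenizer

    tokenizer = AutoTokenizer.from_pretrained("gpt2")
    if tokenizer.pad_token_id is None:
        # GPT-2 has no pad token by default; reuse EOS so padding-based pooling works
        tokenizer.pad_token = tokenizer.eos_token

    base_model = AutoModelForCausalLM.from_pretrained("gpt2")
    embedding_module = EmbeddingModel(
        base_model=base_model,
        tokenizer=tokenizer,
        max_seq_length=128,
        pooling="last",
    )

    # EmbeddingModel already returns a 'sentence_embedding', so it can serve as
    # the only module in the SentenceTransformer pipeline.
    model = SentenceTransformer(modules=[embedding_module])
    embeddings = model.encode(["Hello world", "Sentence embeddings from a causal LM"])
    print(embeddings.shape)  # (2, hidden_size)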