bowang0911 commited on
Commit
29fee2c
·
1 Parent(s): aedbcae

feat: add custom transformer layer and contextual pooling

Browse files
Files changed (2) hide show
  1. contextual_pooling.py +192 -0
  2. custom_st.py +199 -0
contextual_pooling.py ADDED
@@ -0,0 +1,192 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ Custom Pooling module for Late Chunking contextual embeddings.
3
+
4
+ This module extracts chunk-level embeddings from concatenated token embeddings.
5
+ Works together with custom_st.py to enable late chunking.
6
+
7
+ Usage:
8
+ Place this file alongside custom_st.py in your model directory.
9
+ """
10
+
11
+ from __future__ import annotations
12
+
13
+ import json
14
+ import os
15
+ from typing import Any
16
+
17
+ import torch
18
+ import torch.nn as nn
19
+
20
+
21
def mean_pooling_chunked(
    token_embeddings: torch.Tensor, attention_mask: torch.Tensor
) -> torch.Tensor:
    """
    Vectorized masked mean pooling over the per-chunk token dimension.

    Args:
        token_embeddings: (batch_size, num_chunks, chunk_max_seq_length, hidden_dim)
        attention_mask: (batch_size, num_chunks, chunk_max_seq_length)

    Returns:
        pooled: (batch_size, num_chunks, hidden_dim) — mean of the unmasked
        token embeddings within each chunk.
    """
    # Broadcast the mask over the hidden dimension: (b, c, s) -> (b, c, s, 1).
    mask = attention_mask.float().unsqueeze(-1)

    # Masked positions contribute zero to the per-chunk sums.
    totals = (token_embeddings * mask).sum(dim=2)

    # Clamp guards against division by zero for fully-masked chunks.
    counts = mask.sum(dim=2).clamp(min=1e-9)

    return totals / counts
44
+
45
+
46
class Pooling(nn.Module):
    """
    Pooling module for late chunking that extracts individual chunk embeddings.

    This module reads 'num_chunks_list' from features (set by custom_st.py
    Transformer) and extracts chunk embeddings by reshaping the concatenated
    token embeddings and mean pooling within each fixed-length chunk.

    Args:
        chunk_max_seq_length: Maximum sequence length for each chunk in tokens.
            Must match the value in custom_st.py Transformer.

    Config keys:
        - chunk_max_seq_length: The fixed token length per chunk

    Example:
        >>> pooling = Pooling(chunk_max_seq_length=128)
        >>> # Input: token embeddings from 3 chunks of 128 tokens each
        >>> features = {
        ...     'token_embeddings': torch.randn(1, 384, 768),  # batch=1, seq=384
        ...     'attention_mask': torch.ones(1, 384),
        ...     'num_chunks_list': [3],  # Set by Transformer.tokenize()
        ... }
        >>> output = pooling(features)
        >>> output['chunk_embeddings'][0].shape  # list: one tensor per document
        torch.Size([3, 768])
    """

    def __init__(self, chunk_max_seq_length: int = 128, **kwargs) -> None:
        super().__init__()
        if chunk_max_seq_length <= 0:
            raise ValueError("chunk_max_seq_length must be positive")
        self.config_keys = ["chunk_max_seq_length"]
        self.chunk_max_seq_length = chunk_max_seq_length

    def forward(
        self,
        features: dict[str, torch.Tensor],
        **kwargs,
    ) -> dict[str, torch.Tensor]:
        """
        Forward pass that extracts chunk embeddings via mean pooling.

        Args:
            features: Dictionary with 'token_embeddings', 'attention_mask',
                and 'num_chunks_list'
            **kwargs: Additional arguments (ignored)

        Returns:
            The same features dict with 'chunk_embeddings' added: a list of
            length batch_size whose i-th entry is a (num_chunks_i, hidden_dim)
            tensor. A list (not a stacked tensor) is used because documents
            may contain different numbers of chunks.

        Raises:
            ValueError: If 'num_chunks_list' is missing, or a document's
                sequence length does not equal num_chunks * chunk_max_seq_length.
        """
        token_embeddings = features["token_embeddings"]
        attention_mask = features["attention_mask"]
        num_chunks_list = features.get("num_chunks_list")

        if num_chunks_list is None:
            raise ValueError(
                "num_chunks_list not found in features. "
                "Make sure you're using the custom_st.py Transformer module."
            )

        # Extract chunk embeddings per document; each document keeps its own
        # chunk count, so results are collected in a list.
        batch_chunk_embeddings = []
        for i, num_chunks in enumerate(num_chunks_list):
            pooled = self._extract_chunk_embeddings(
                token_embeddings[i : i + 1],  # keep batch dim for the reshape
                attention_mask[i : i + 1],
                num_chunks,
            )
            batch_chunk_embeddings.append(pooled[0])  # drop batch dim

        features["chunk_embeddings"] = batch_chunk_embeddings
        return features

    def _extract_chunk_embeddings(
        self,
        token_embeddings: torch.Tensor,
        attention_mask: torch.Tensor,
        num_chunks: int,
    ) -> torch.Tensor:
        """
        Extract chunk embeddings from concatenated token embeddings.

        Args:
            token_embeddings: (1, seq_len, hidden_dim)
            attention_mask: (1, seq_len)
            num_chunks: Number of chunks packed into the sequence

        Returns:
            chunk_embeddings: (1, num_chunks, hidden_dim)
        """
        batch_size, seq_len, hidden_dim = token_embeddings.shape

        # The Transformer pads every chunk to exactly chunk_max_seq_length,
        # so the concatenated length must match precisely.
        expected_seq_len = num_chunks * self.chunk_max_seq_length
        if seq_len != expected_seq_len:
            raise ValueError(
                f"Sequence length {seq_len} does not match num_chunks * chunk_max_seq_length "
                f"({num_chunks} * {self.chunk_max_seq_length} = {expected_seq_len})"
            )

        # Reshape to (batch, num_chunks, chunk_max_seq_length, hidden) and
        # (batch, num_chunks, chunk_max_seq_length) so all chunks pool at once.
        emb = token_embeddings.view(
            batch_size, num_chunks, self.chunk_max_seq_length, hidden_dim
        )
        mask = attention_mask.view(
            batch_size, num_chunks, self.chunk_max_seq_length
        ).float().unsqueeze(-1)

        # Masked mean per chunk (same math as mean_pooling_chunked, inlined so
        # the module is self-contained); clamp avoids division by zero.
        summed = (emb * mask).sum(dim=2)
        counts = mask.sum(dim=2).clamp(min=1e-9)
        return summed / counts

    def get_sentence_embedding_dimension(self) -> int | None:
        """Return None: the embedding dimension comes from the transformer."""
        return None

    def get_config_dict(self) -> dict[str, Any]:
        """Return the config dictionary used for serialization."""
        return {key: getattr(self, key) for key in self.config_keys}

    def save(self, output_path: str, safe_serialization: bool = True) -> None:
        """
        Save the module configuration to <output_path>/config.json.

        Args:
            output_path: Directory to write into.
            safe_serialization: Accepted for interface compatibility with
                sentence-transformers; this module has no weights to serialize.
        """
        with open(
            os.path.join(output_path, "config.json"), "w", encoding="utf-8"
        ) as f:
            json.dump(self.get_config_dict(), f, indent=2)

    @staticmethod
    def load(input_path: str) -> "Pooling":
        """Load the module from a directory containing config.json."""
        config_path = os.path.join(input_path, "config.json")
        with open(config_path, encoding="utf-8") as f:
            config = json.load(f)
        return Pooling(**config)
custom_st.py ADDED
@@ -0,0 +1,199 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ Custom Sentence Transformer module for Late Chunking contextual embeddings.
3
+
4
+ This module ONLY accepts nested lists of chunks: list[list[str]]
5
+ All chunks are encoded together to preserve positional context (late chunking).
6
+
7
+ Usage:
8
+ Place this file alongside contextual_pooling.py in your model directory.
9
+ Load with: SentenceTransformer("model-path", trust_remote_code=True)
10
+
11
+ documents = [
12
+ ["chunk 1", "chunk 2", "chunk 3"],
13
+ ["chunk 1", "chunk 2"]
14
+ ]
15
+ embeddings = model.encode(documents)
16
+ """
17
+
18
+ from __future__ import annotations
19
+
20
+ import json
21
+ import os
22
+ from typing import Any, Dict, List, Optional
23
+
24
+ import torch
25
+ from torch import nn
26
+ from transformers import AutoConfig, AutoModel, AutoTokenizer
27
+
28
+
29
class Transformer(nn.Module):
    """
    Custom Transformer wrapper for late chunking contextual embeddings.

    This class handles tokenization of nested list inputs (documents with
    chunks) and concatenates chunks with fixed chunk_max_seq_length so that
    all chunks of a document are encoded together (late chunking).

    Args:
        model_name_or_path: Hugging Face model name or path
        chunk_max_seq_length: Maximum sequence length for each chunk in tokens.
            Each chunk will be padded or truncated to exactly this many tokens.
        model_args: Keyword arguments passed to AutoModel.from_pretrained()
        tokenizer_args: Keyword arguments passed to AutoTokenizer.from_pretrained()
        config_args: Keyword arguments passed to AutoConfig.from_pretrained()
        cache_dir: Cache directory for Hugging Face models
        tokenizer_name_or_path: Tokenizer name or path (defaults to model_name_or_path)
    """

    def __init__(
        self,
        model_name_or_path: str,
        chunk_max_seq_length: int = 128,
        model_args: Optional[Dict[str, Any]] = None,
        tokenizer_args: Optional[Dict[str, Any]] = None,
        config_args: Optional[Dict[str, Any]] = None,
        cache_dir: Optional[str] = None,
        tokenizer_name_or_path: Optional[str] = None,
        **kwargs,
    ) -> None:
        super().__init__()

        self.config_keys = ["chunk_max_seq_length"]
        self.chunk_max_seq_length = chunk_max_seq_length

        model_args = model_args if model_args is not None else {}
        tokenizer_args = tokenizer_args if tokenizer_args is not None else {}
        config_args = config_args if config_args is not None else {}

        # Load config first so model_args can still override model kwargs.
        config = AutoConfig.from_pretrained(
            model_name_or_path, **config_args, cache_dir=cache_dir
        )
        self.auto_model = AutoModel.from_pretrained(
            model_name_or_path, config=config, cache_dir=cache_dir, **model_args
        )

        # Tokenizer may live at a different path than the model weights.
        self.tokenizer = AutoTokenizer.from_pretrained(
            tokenizer_name_or_path if tokenizer_name_or_path else model_name_or_path,
            cache_dir=cache_dir,
            **tokenizer_args,
        )

    def forward(
        self, features: Dict[str, torch.Tensor]
    ) -> Dict[str, torch.Tensor]:
        """
        Forward pass through the transformer model.

        Args:
            features: Dictionary with 'input_ids' and 'attention_mask'

        Returns:
            The same dict with 'token_embeddings' (the last hidden state) added
        """
        output_states = self.auto_model(
            input_ids=features["input_ids"],
            attention_mask=features["attention_mask"],
            return_dict=True,
        )

        # Token-level embeddings; chunk pooling happens in the Pooling module.
        features["token_embeddings"] = output_states.last_hidden_state

        return features

    def tokenize(
        self, texts: List[List[str]], padding: bool = True
    ) -> Dict[str, torch.Tensor]:
        """
        Tokenize nested list of chunks for late chunking.

        Every chunk is padded/truncated to exactly chunk_max_seq_length tokens
        and the chunks of each document are concatenated into one sequence.

        Args:
            texts: list[list[str]] - list of documents, each document a list of
                chunk texts. All documents in one batch must have the same
                number of chunks (the result is a rectangular tensor).
            padding: Accepted for sentence-transformers interface compatibility;
                chunks are always padded to max_length regardless of this flag.

        Returns:
            Dictionary with 'input_ids', 'attention_mask' (both of shape
            (num_docs, num_chunks * chunk_max_seq_length)) and
            'num_chunks_list' (list of per-document chunk counts, consumed
            by the Pooling module)

        Raises:
            ValueError: On empty input, non-nested input, a document with no
                chunks, or documents with differing chunk counts.
        """
        if not texts:
            raise ValueError("Input texts cannot be empty")

        if not isinstance(texts[0], list):
            raise ValueError(
                "This model only accepts nested lists: list[list[str]]. "
                "Expected format: [['chunk1', 'chunk2'], ['chunk1']]"
            )

        for i, doc_chunks in enumerate(texts):
            if not doc_chunks:
                raise ValueError(
                    f"Document at index {i} has no chunks. "
                    "Each document must have at least one chunk."
                )

        # Ragged batches would previously fail inside torch.tensor() with a
        # cryptic error; fail fast with an actionable message instead.
        chunk_counts = [len(doc_chunks) for doc_chunks in texts]
        if len(set(chunk_counts)) > 1:
            raise ValueError(
                "All documents in a batch must have the same number of chunks "
                f"to form a rectangular tensor; got chunk counts {chunk_counts}. "
                "Encode documents with differing chunk counts in separate calls."
            )

        all_input_ids = []
        all_attention_masks = []
        num_chunks_list = []

        for doc_chunks in texts:
            # Tokenize each chunk to a fixed size, then concatenate so the
            # whole document is encoded as one contiguous sequence.
            chunk_input_ids: List[int] = []
            chunk_attention_masks: List[int] = []

            for chunk in doc_chunks:
                chunk_tokens = self.tokenizer(
                    chunk,
                    max_length=self.chunk_max_seq_length,
                    truncation=True,
                    padding="max_length",
                    return_tensors=None,
                )
                chunk_input_ids.extend(chunk_tokens["input_ids"])
                chunk_attention_masks.extend(chunk_tokens["attention_mask"])

            all_input_ids.append(chunk_input_ids)
            all_attention_masks.append(chunk_attention_masks)
            num_chunks_list.append(len(doc_chunks))

        return {
            "input_ids": torch.tensor(all_input_ids, dtype=torch.long),
            "attention_mask": torch.tensor(all_attention_masks, dtype=torch.long),
            "num_chunks_list": num_chunks_list,  # Consumed by Pooling
        }

    def get_word_embedding_dimension(self) -> int:
        """Return the hidden size of the underlying transformer."""
        return self.auto_model.config.hidden_size

    def save(self, output_path: str, safe_serialization: bool = True) -> None:
        """Save the model, tokenizer, and custom configuration."""
        self.auto_model.save_pretrained(
            output_path, safe_serialization=safe_serialization
        )
        self.tokenizer.save_pretrained(output_path)

        # Persist chunk_max_seq_length so load() can restore it.
        with open(
            os.path.join(output_path, "sentence_bert_config.json"), "w", encoding="utf-8"
        ) as f:
            json.dump(
                {key: getattr(self, key) for key in self.config_keys},
                f,
                indent=2,
            )

    @staticmethod
    def load(input_path: str) -> "Transformer":
        """Load the model from a directory, restoring the saved config."""
        config_path = os.path.join(input_path, "sentence_bert_config.json")
        if os.path.exists(config_path):
            with open(config_path, encoding="utf-8") as f:
                config = json.load(f)
        else:
            # Fall back to defaults when no custom config was saved.
            config = {}

        return Transformer(model_name_or_path=input_path, **config)