leo-vnuuet committed on
Commit c14f5af · 1 Parent(s): feb9ad6

update README, usage implementation

embedder/__pycache__/colqwen3.5_embedder.cpython-38.pyc ADDED
Binary file (12 kB).
 
embedder/colqwen3_5_embedder.py ADDED
@@ -0,0 +1,365 @@
+ from typing import Any, Dict, List, Optional, Union
+ from dataclasses import dataclass
+ import unicodedata
+ from PIL import Image
+ import logging
+
+ from peft import PeftModel
+ import torch
+ import torch.nn.functional as F
+
+ from transformers.models.qwen3_5.modeling_qwen3_5 import Qwen3_5PreTrainedModel, Qwen3_5Model
+ from transformers.models.qwen3_5.configuration_qwen3_5 import Qwen3_5Config
+
+ from qwen_vl_utils.vision_process import process_vision_info
+
+ from transformers import AutoProcessor
+ from transformers.modeling_outputs import ModelOutput
+ from transformers.utils import TransformersKwargs
+ from transformers.processing_utils import Unpack
+ from transformers.cache_utils import Cache
+
+
+ MAX_LENGTH = 2048
+ IMAGE_BASE_FACTOR = 16
+ IMAGE_FACTOR = IMAGE_BASE_FACTOR * 2
+ MIN_PIXELS = 4 * IMAGE_FACTOR * IMAGE_FACTOR  # 4096
+ MAX_PIXELS = 1024 * IMAGE_FACTOR * IMAGE_FACTOR  # 1048576
+ PAD_TOKEN = "<|endoftext|>"
+
+
+ logger = logging.getLogger(__name__)
+
+
+ @dataclass
+ class ColQwen3_5ForEmbeddingOutput(ModelOutput):
+     """Output of ColQwen3_5ForEmbedding.
+
+     Args:
+         hidden_states (`torch.FloatTensor`): Last hidden state of the model [B, N, D].
+         attention_mask (`torch.Tensor`): Attention mask [B, N].
+         attentions (`tuple`, optional): Per-layer attention tensors when
+             forward() is called with output_attentions=True. Each entry is
+             [B, H, N, N] for full-attention layers or None for DeltaNet layers.
+     """
+     hidden_states: Optional[torch.FloatTensor] = None
+     attention_mask: Optional[torch.Tensor] = None
+     attentions: Optional[tuple] = None
+
+
+ class ColQwen3_5ForEmbedding(Qwen3_5PreTrainedModel):
+     _checkpoint_conversion_mapping = {}
+     accepts_loss_kwargs = False
+     config: Qwen3_5Config
+
+     def __init__(self, config):
+         super().__init__(config)
+         self.model = Qwen3_5Model(config)
+         self.post_init()
+
+     def get_input_embeddings(self):
+         return self.model.get_input_embeddings()
+
+     def set_input_embeddings(self, value):
+         self.model.set_input_embeddings(value)
+
+     def get_decoder(self):
+         return self.model.get_decoder()
+
+     def set_decoder(self, decoder):
+         self.model.set_decoder(decoder)
+
+     def get_image_features(self, pixel_values: torch.FloatTensor, image_grid_thw: Optional[torch.LongTensor] = None):
+         return self.model.get_image_features(pixel_values, image_grid_thw)
+
+     @property
+     def language_model(self):
+         return self.model.language_model
+
+     @property
+     def vision_model(self):
+         return self.model.visual
+
+     def forward(
+         self,
+         input_ids: torch.LongTensor = None,
+         attention_mask: Optional[torch.Tensor] = None,
+         position_ids: Optional[torch.LongTensor] = None,
+         past_key_values: Optional[Cache] = None,
+         inputs_embeds: Optional[torch.FloatTensor] = None,
+         pixel_values: Optional[torch.Tensor] = None,
+         image_grid_thw: Optional[torch.LongTensor] = None,
+         cache_position: Optional[torch.LongTensor] = None,
+         logits_to_keep: Union[int, torch.Tensor] = 0,
+         output_attentions: bool = False,
+         **kwargs: Unpack[TransformersKwargs],  # type: ignore
+     ) -> Union[tuple, ColQwen3_5ForEmbeddingOutput]:
+         r"""
+         Returns:
+             ColQwen3_5ForEmbeddingOutput with fields:
+             - `hidden_states` ([B, N, D]): Last hidden state of the model.
+             - `attention_mask` ([B, N]): Attention mask.
+             - `attentions` (tuple | None): Per-layer attention tensors when
+               output_attentions=True. GQA layers → [B, H, N, N]; DeltaNet
+               layers (Qwen3.5 hybrid) → None.
+         """
+         outputs = self.model(
+             input_ids=input_ids,
+             pixel_values=pixel_values,
+             image_grid_thw=image_grid_thw,
+             position_ids=position_ids,
+             attention_mask=attention_mask,
+             past_key_values=past_key_values,
+             inputs_embeds=inputs_embeds,
+             cache_position=cache_position,
+             output_attentions=output_attentions,
+             **kwargs,
+         )
+
+         return ColQwen3_5ForEmbeddingOutput(
+             hidden_states=outputs.last_hidden_state,
+             attention_mask=attention_mask,
+             attentions=outputs.attentions if output_attentions else None,
+         )
+
+
+ class ColQwen3_5Embedder:
+     def __init__(
+         self,
+         model_name_or_path: str = "Qwen/Qwen3.5-0.8B",
+         lora_checkpoint: Optional[str] = None,
+         max_length: int = MAX_LENGTH,
+         min_pixels: int = MIN_PIXELS,
+         max_pixels: int = MAX_PIXELS,
+         default_instruction: str = "Represent the user's input.",
+         embed_dim: Optional[int] = None,
+         **kwargs,
+     ):
+
+         device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
+         self.max_length = max_length
+         self.min_pixels = min_pixels
+         self.max_pixels = max_pixels
+         self.embed_dim = embed_dim
+
+         self.default_instruction = default_instruction
+
+         self.model = ColQwen3_5ForEmbedding.from_pretrained(model_name_or_path).to(device)  # type: ignore
+
+         if lora_checkpoint:
+             self.model = PeftModel.from_pretrained(self.model, lora_checkpoint)
+             self.model = self.model.to(torch.bfloat16)
+
+         self.processor = AutoProcessor.from_pretrained(model_name_or_path, padding_side="right")  # type: ignore
+
+         self.model.eval()
+
+     @torch.no_grad()
+     def forward(self, inputs: Dict[str, Any]) -> Dict[str, torch.Tensor]:
+         outputs = self.model(**inputs)
+         return {
+             "embeddings": outputs.hidden_states,
+             "attention_mask": outputs.attention_mask
+         }
+
+     def truncate_tokens(self, token_ids: List[int], max_length: int) -> List[int]:
+         if len(token_ids) <= max_length:
+             return token_ids
+
+         special_token_ids = set(self.processor.tokenizer.all_special_ids)  # special tokens are always kept; only regular tokens are trimmed
+         num_special = sum(1 for token_idx in token_ids if token_idx in special_token_ids)
+         num_non_special_to_keep = max_length - num_special
+
+         final_token_ids = []
+         non_special_kept_count = 0
+
+         for token_idx in token_ids:
+             if token_idx in special_token_ids:
+                 final_token_ids.append(token_idx)
+             elif non_special_kept_count < num_non_special_to_keep:
+                 final_token_ids.append(token_idx)
+                 non_special_kept_count += 1
+
+         return final_token_ids
+
+     def format_model_input(
+         self, text: Optional[str] = None,
+         image: Optional[Union[str, Image.Image]] = None,
+         instruction: Optional[str] = None,
+     ) -> List[Dict]:
+
+         # Ensure instruction ends with punctuation
+         if instruction:
+             instruction = instruction.strip()
+             if instruction and not unicodedata.category(instruction[-1]).startswith('P'):
+                 instruction = instruction + '.'
+
+         content = []
+         conversation = [
+             {"role": "system", "content": [{"type": "text", "text": instruction or self.default_instruction}]},
+             {"role": "user", "content": content}
+         ]
+
+         # Add text, image content to conversation
+         if not text and not image:
+             content.append({'type': 'text', 'text': "NULL"})
+             return conversation
+
+         if image:
+             image_content = None
+             if isinstance(image, Image.Image):
+                 image_content = image
+             elif isinstance(image, str):
+                 image_content = image if image.startswith(('http', 'oss')) else 'file://' + image
+             else:
+                 raise TypeError(f"Unrecognized image type: {type(image)}")
+
+             # Add image input details to content
+             if image_content:
+                 content.append({
+                     'type': 'image', 'image': image_content,
+                     "min_pixels": self.min_pixels,
+                     "max_pixels": self.max_pixels
+                 })
+
+         if text:
+             content.append({'type': 'text', 'text': text})
+
+         return conversation
+
+     def _preprocess_inputs(self, conversations: List[List[Dict]]) -> Dict[str, torch.Tensor]:
+         text = self.processor.apply_chat_template(
+             conversations, add_generation_prompt=True, tokenize=False
+         )
+
+         try:
+             images, video_inputs, video_kwargs = process_vision_info(
+                 conversations, image_patch_size=16,
+                 return_video_metadata=True, return_video_kwargs=True
+             )
+
+         except Exception as e:
+             logger.error(f"Error in processing vision info: {e}")
+             images = None
+             video_inputs = None
+             video_kwargs = {'do_sample_frames': False}
+             text = self.processor.apply_chat_template(
+                 [{'role': 'user', 'content': [{'type': 'text', 'text': 'NULL'}]}],
+                 add_generation_prompt=True, tokenize=False
+             )
+
+         if video_inputs is not None:
+             videos, video_metadata = zip(*video_inputs)
+             videos = list(videos)
+             video_metadata = list(video_metadata)
+         else:
+             videos, video_metadata = None, None
+
+         inputs = self.processor(
+             text=text, images=images, videos=videos, video_metadata=video_metadata, truncation=True,
+             max_length=self.max_length, padding=True, do_resize=False, return_tensors='pt',
+             **video_kwargs
+         )
+         return inputs
+
+     @staticmethod
+     def _pooling_last(hidden_state: torch.Tensor, attention_mask: torch.Tensor) -> torch.Tensor:
+         flipped_tensor = attention_mask.flip(dims=[1])  # reverse each row so the last attended token comes first
+         last_one_positions = flipped_tensor.argmax(dim=1)  # position of that token in the reversed row
+         col = attention_mask.shape[1] - last_one_positions - 1  # map back to the original column index
+         row = torch.arange(hidden_state.shape[0], device=hidden_state.device)
+         return hidden_state[row, col]
+
+     def _truncate_dimensions(self, embeddings: torch.Tensor) -> torch.Tensor:
+         # Truncate to embed_dim if specified
+         if self.embed_dim is not None and embeddings.shape[-1] > self.embed_dim:
+             return embeddings[:, :, :self.embed_dim]
+         return embeddings
+
+     # Process inputs to generate normalized embeddings
+     def process(self, inputs: List[Dict[str, Any]], normalize: bool = True, pooling: bool = False) -> tuple:
+         conversations = [self.format_model_input(
+             text=ele.get('text'),
+             image=ele.get('image'),
+             instruction=ele.get('instruction'),
+         ) for ele in inputs]
+
+         processed_inputs = self._preprocess_inputs(conversations)
+         processed_inputs = {k: v.to(self.model.device) for k, v in processed_inputs.items()}
+
+         outputs = self.forward(processed_inputs)
+
+         embeddings = outputs['embeddings']
+         attention_mask = outputs['attention_mask']
+
+         if pooling:
+             embeddings = self._pooling_last(embeddings, attention_mask)
+             if normalize:
+                 embeddings = F.normalize(embeddings, p=2, dim=-1)
+
+             return embeddings, attention_mask
+
+         else:
+             embeddings = self._truncate_dimensions(embeddings)
+             if normalize:
+                 embeddings = F.normalize(embeddings, p=2, dim=-1)
+
+             return embeddings, attention_mask
+
+     @staticmethod
+     def score_maxsim(
+         query_embeddings: torch.Tensor,
+         doc_embeddings: torch.Tensor,
+         query_mask: torch.Tensor,
+         doc_mask: torch.Tensor,
+         device: str = "cuda" if torch.cuda.is_available() else "cpu"
+     ) -> torch.Tensor:
+         """
+         Compute MaxSim scores between queries and documents (multi-vector).
+
+         Args:
+             query_embeddings: (Q, Lq, D) — multi-vector query embeddings (normalized)
+             doc_embeddings: (D_count, Ld, D) — multi-vector doc embeddings (normalized)
+             query_mask: (Q, Lq) — attention mask for queries
+             doc_mask: (D_count, Ld) — attention mask for docs
+
+         Returns:
+             scores: (Q, D_count) — MaxSim similarity matrix
+         """
+         query_embeddings, doc_embeddings = query_embeddings.to(device), doc_embeddings.to(device)  # both operands must share a device for the einsum below
+         query_mask = query_mask.to(device)
+         doc_mask = doc_mask.to(device)
+
+         sim = torch.einsum("qid,njd->qinj", query_embeddings, doc_embeddings)
+
+         doc_pad_mask = ~doc_mask.bool()  # (Ndoc, Ld)
+         sim = sim.masked_fill(doc_pad_mask.unsqueeze(0).unsqueeze(0), float("-inf"))
+
+         query_pad_mask = ~query_mask.bool()  # (Q, Lq)
+         sim = sim.masked_fill(query_pad_mask.unsqueeze(2).unsqueeze(-1), 0.0)
+
+         scores = sim.max(dim=-1).values  # (Q, Lq, Ndoc)
+         scores = scores.sum(dim=1)  # (Q, Ndoc)
+
+         return scores
+
+     @staticmethod
+     def score_dense(
+         query_embeddings: torch.Tensor,
+         doc_embeddings: torch.Tensor,
+         device: str = "cuda" if torch.cuda.is_available() else "cpu"
+     ) -> torch.Tensor:
+         """
+         Compute dot-product scores between pooled query and doc embeddings.
+
+         Args:
+             query_embeddings: (Q, D) — pooled + normalized query embeddings
+             doc_embeddings: (D_count, D) — pooled + normalized doc embeddings
+
+         Returns:
+             scores: (Q, D_count)
+         """
+         doc_embeddings = doc_embeddings.to(device)
+         query_embeddings = query_embeddings.to(device)
+         return torch.matmul(query_embeddings, doc_embeddings.T)
+
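
The commit message mentions a usage implementation; for orientation, here is a minimal usage sketch based only on the class added above. The query/document contents, image file name, and adapter path are placeholders (nothing here ships with this commit), and it assumes a transformers build that provides the `qwen3_5` model files plus `qwen_vl_utils` and `peft` installed.

```python
from embedder.colqwen3_5_embedder import ColQwen3_5Embedder

# Placeholder identifiers: swap in your own base model and optional LoRA adapter.
embedder = ColQwen3_5Embedder(
    model_name_or_path="Qwen/Qwen3.5-0.8B",
    lora_checkpoint=None,        # e.g. a ColQwen-style adapter directory
)

queries = [{"text": "What was the quarterly revenue?", "instruction": "Represent the user's query."}]
docs = [
    {"image": "page_001.png"},   # local path, URL, or a PIL.Image
    {"text": "Revenue grew 12% quarter over quarter."},
]

# Multi-vector (late-interaction) retrieval: one embedding per token, scored with MaxSim.
q_emb, q_mask = embedder.process(queries, normalize=True, pooling=False)
d_emb, d_mask = embedder.process(docs, normalize=True, pooling=False)
maxsim_scores = ColQwen3_5Embedder.score_maxsim(q_emb, d_emb, q_mask, d_mask)  # (1, 2)

# Single-vector retrieval: last-token pooling, then a dot product.
q_vec, _ = embedder.process(queries, normalize=True, pooling=True)
d_vec, _ = embedder.process(docs, normalize=True, pooling=True)
dense_scores = ColQwen3_5Embedder.score_dense(q_vec, d_vec)  # (1, 2)
```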
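The MaxSim scoring itself does not depend on the model, so its behaviour can be checked with tiny synthetic tensors. Everything below (shapes, mask pattern, seed) is made up for illustration; it reproduces by hand what `score_maxsim` computes when the query has no padding.

```python
import torch
import torch.nn.functional as F

torch.manual_seed(0)

q = F.normalize(torch.randn(1, 2, 4), dim=-1)   # 1 query, 2 query tokens, dim 4
d = F.normalize(torch.randn(1, 3, 4), dim=-1)   # 1 doc, 3 doc tokens (last one is padding)
q_mask = torch.tensor([[1, 1]])
d_mask = torch.tensor([[1, 1, 0]])

# MaxSim by hand: pairwise token similarities, padded doc tokens masked to -inf,
# then for each query token take the best doc token and sum over query tokens.
sim = torch.einsum("qid,njd->qinj", q, d)                         # (1, 2, 1, 3)
sim = sim.masked_fill(~d_mask.bool()[None, None], float("-inf"))
expected = sim.max(dim=-1).values.sum(dim=1)                      # (1, 1)

# With no query padding this matches:
#   ColQwen3_5Embedder.score_maxsim(q, d, q_mask, d_mask, device="cpu")
print(expected)
```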