import logging
import unicodedata
from dataclasses import dataclass
from typing import Any, Dict, List, Optional, Union

import torch
import torch.nn.functional as F
from peft import PeftModel
from PIL import Image
from qwen_vl_utils.vision_process import process_vision_info
from transformers import AutoProcessor
from transformers.cache_utils import Cache
from transformers.modeling_outputs import ModelOutput
from transformers.models.qwen3_5.configuration_qwen3_5 import Qwen3_5Config
from transformers.models.qwen3_5.modeling_qwen3_5 import Qwen3_5Model, Qwen3_5PreTrainedModel
from transformers.processing_utils import Unpack
from transformers.utils import TransformersKwargs
|
|
|
|
MAX_LENGTH = 2048
IMAGE_BASE_FACTOR = 16  # patch size; matches image_patch_size passed to process_vision_info below
IMAGE_FACTOR = IMAGE_BASE_FACTOR * 2  # image dims align to 32 px (patch size x assumed 2x2 spatial merge)
MIN_PIXELS = 4 * IMAGE_FACTOR * IMAGE_FACTOR
MAX_PIXELS = 1024 * IMAGE_FACTOR * IMAGE_FACTOR
PAD_TOKEN = "<|endoftext|>"
|
|
|
|
| logger = logging.getLogger(__name__) |
|
|
|
|
| @dataclass |
| class ColQwen3_5ForEmbeddingOutput(ModelOutput): |
| """Output of ColQwen3_5ForEmbedding. |
| |
| Args: |
| hidden_states (`torch.FloatTensor`): Last hidden state of the model [B, N, D]. |
| attention_mask (`torch.Tensor`): Attention mask [B, N]. |
| attentions (`tuple`, optional): Per-layer attention tensors when |
| forward() is called with output_attentions=True. Each entry is |
| [B, H, N, N] for full-attention layers or None for DeltaNet layers. |
| """ |
| hidden_states: Optional[torch.FloatTensor] = None |
| attention_mask: Optional[torch.Tensor] = None |
| attentions: Optional[tuple] = None |
| |
|
|
| class ColQwen3_5ForEmbedding(Qwen3_5PreTrainedModel): |
| _checkpoint_conversion_mapping = {} |
| accepts_loss_kwargs = False |
| config: Qwen3_5Config |
| |
| def __init__(self, config): |
| super().__init__(config) |
| self.model = Qwen3_5Model(config) |
| self.post_init() |
| |
| def get_input_embeddings(self): |
| return self.model.get_input_embeddings() |
| |
| def set_input_embeddings(self, value): |
| self.model.set_input_embeddings(value) |
| |
| def get_decoder(self): |
| return self.model.get_decoder() |
| |
| def set_decoder(self, decoder): |
| self.model.set_decoder(decoder) |
| |
| def get_image_features(self, pixel_values: torch.FloatTensor, image_grid_thw: Optional[torch.LongTensor] = None): |
| return self.model.get_image_features(pixel_values, image_grid_thw) |
| |
| @property |
| def language_model(self): |
| return self.model.language_model |
| |
| @property |
| def vision_model(self): |
| return self.model.visual |
| |
| def forward( |
| self, |
        input_ids: Optional[torch.LongTensor] = None,
| attention_mask: Optional[torch.Tensor] = None, |
| position_ids: Optional[torch.LongTensor] = None, |
| past_key_values: Optional[Cache] = None, |
| inputs_embeds: Optional[torch.FloatTensor] = None, |
| pixel_values: Optional[torch.Tensor] = None, |
| image_grid_thw: Optional[torch.LongTensor] = None, |
| cache_position: Optional[torch.LongTensor] = None, |
| logits_to_keep: Union[int, torch.Tensor] = 0, |
| output_attentions: bool = False, |
| **kwargs: Unpack[TransformersKwargs], |
| ) -> Union[tuple, ColQwen3_5ForEmbeddingOutput]: |
| r""" |
| Returns: |
| ColQwen3_5ForEmbeddingOutput with fields: |
| - `hidden_states` ([B, N, D]): Last hidden state of the model. |
| - `attention_mask` ([B, N]): Attention mask. |
| - `attentions` (tuple | None): Per-layer attention tensors when |
| output_attentions=True. GQA layers → [B, H, N, N]; DeltaNet |
| layers (Qwen3.5 hybrid) → None. |
| """ |
| outputs = self.model( |
| input_ids=input_ids, |
| pixel_values=pixel_values, |
| image_grid_thw=image_grid_thw, |
| position_ids=position_ids, |
| attention_mask=attention_mask, |
| past_key_values=past_key_values, |
| inputs_embeds=inputs_embeds, |
| cache_position=cache_position, |
| output_attentions=output_attentions, |
| **kwargs, |
| ) |
|
|
| return ColQwen3_5ForEmbeddingOutput( |
| hidden_states=outputs.last_hidden_state, |
| attention_mask=attention_mask, |
| attentions=outputs.attentions if output_attentions else None, |
| ) |
| |
| |
| class ColQwen3_5Embedder: |
| def __init__( |
| self, |
| model_name_or_path: str = "Qwen/Qwen3.5-0.8B", |
| lora_checkpoint: Optional[str] = None, |
| max_length: int = MAX_LENGTH, |
| min_pixels: int = MIN_PIXELS, |
| max_pixels: int = MAX_PIXELS, |
| default_instruction: str = "Represent the user's input.", |
| embed_dim: Optional[int] = None, |
| **kwargs, |
| ): |
| |
| device = torch.device("cuda" if torch.cuda.is_available() else "cpu") |
| self.max_length = max_length |
| self.min_pixels = min_pixels |
| self.max_pixels = max_pixels |
| self.embed_dim = embed_dim |
| |
| self.default_instruction = default_instruction |
| |
        self.model = ColQwen3_5ForEmbedding.from_pretrained(model_name_or_path).to(device)

        if lora_checkpoint:
            # Wrap the base model with the adapter and cast everything to bfloat16
            # (assumes the adapter was trained in bf16).
            self.model = PeftModel.from_pretrained(self.model, lora_checkpoint)
            self.model = self.model.to(torch.bfloat16)

        # Right padding keeps the last real token findable by _pooling_last.
        self.processor = AutoProcessor.from_pretrained(model_name_or_path, padding_side="right")

        self.model.eval()
| |
| @torch.no_grad() |
| def forward(self, inputs: Dict[str, Any]) -> Dict[str, torch.Tensor]: |
| outputs = self.model(**inputs) |
| return { |
| "embeddings": outputs.hidden_states, |
| "attention_mask": outputs.attention_mask |
| } |
| |
    def truncate_tokens(self, token_ids: List[int], max_length: int) -> List[int]:
        """Truncate `token_ids` to at most `max_length` tokens, preserving all special tokens."""
        if len(token_ids) <= max_length:
            return token_ids
|
|
| special_token_ids = set(self.processor.tokenizer.all_special_ids) |
| num_special = sum(1 for token_idx in token_ids if token_idx in special_token_ids) |
| num_non_special_to_keep = max_length - num_special |
|
|
| final_token_ids = [] |
| non_special_kept_count = 0 |
|
|
| for token_idx in token_ids: |
| if token_idx in special_token_ids: |
| final_token_ids.append(token_idx) |
| elif non_special_kept_count < num_non_special_to_keep: |
| final_token_ids.append(token_idx) |
| non_special_kept_count += 1 |
| |
| return final_token_ids |
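
    # Worked example: with special ids {0, 2} and max_length=4,
    #   truncate_tokens([0, 5, 6, 7, 2], 4) -> [0, 5, 6, 2]
    # (specials always kept; non-special tokens are dropped from the right).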
| |
| def format_model_input( |
| self, text: Optional[str] = None, |
| image: Optional[Union[str, Image.Image]] = None, |
| instruction: Optional[str] = None, |
| ) -> List[Dict]: |
| |
| |
        if instruction:
            instruction = instruction.strip()
            # Append a period unless the instruction already ends in punctuation
            # (any Unicode 'P*' category).
            if instruction and not unicodedata.category(instruction[-1]).startswith('P'):
                instruction = instruction + '.'
|
|
| content = [] |
| conversation = [ |
| {"role": "system", "content": [{"type": "text", "text": instruction or self.default_instruction}]}, |
| {"role": "user", "content": content} |
| ] |
| |
| |
| if not text and not image: |
| content.append({'type': 'text', 'text': "NULL"}) |
| return conversation |
| |
        if image:
            if isinstance(image, Image.Image):
                image_content = image
            elif isinstance(image, str):
                # Remote URLs pass through; local paths get a file:// scheme.
                image_content = image if image.startswith(('http', 'oss')) else 'file://' + image
            else:
                raise TypeError(f"Unrecognized image type: {type(image)}")

            content.append({
                'type': 'image', 'image': image_content,
                "min_pixels": self.min_pixels,
                "max_pixels": self.max_pixels
            })
|
|
| if text: |
| content.append({'type': 'text', 'text': text}) |
| |
| return conversation |
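
    # Illustrative output of format_model_input(text="a cat photo", image="cat.png"):
    #   [{'role': 'system', 'content': [{'type': 'text', 'text': <instruction>}]},
    #    {'role': 'user', 'content': [
    #        {'type': 'image', 'image': 'file://cat.png',
    #         'min_pixels': ..., 'max_pixels': ...},
    #        {'type': 'text', 'text': 'a cat photo'}]}]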
| |
| def _preprocess_inputs(self, conversations: List[List[Dict]]) -> Dict[str, torch.Tensor]: |
| text = self.processor.apply_chat_template( |
| conversations, add_generation_prompt=True, tokenize=False |
| ) |
|
|
| try: |
| images, video_inputs, video_kwargs = process_vision_info( |
| conversations, image_patch_size=16, |
| return_video_metadata=True, return_video_kwargs=True |
| ) |
| |
        except Exception as e:
            logger.error(f"Error in processing vision info: {e}")
            # Fall back to a single text-only "NULL" prompt. Note that this
            # collapses the whole batch to one input.
            images = None
            video_inputs = None
            video_kwargs = {'do_sample_frames': False}
            text = self.processor.apply_chat_template(
                [{'role': 'user', 'content': [{'type': 'text', 'text': 'NULL'}]}],
                add_generation_prompt=True, tokenize=False
            )
|
|
| if video_inputs is not None: |
| videos, video_metadata = zip(*video_inputs) |
| videos = list(videos) |
| video_metadata = list(video_metadata) |
| else: |
| videos, video_metadata = None, None |
|
|
| inputs = self.processor( |
| text=text, images=images, videos=videos, video_metadata=video_metadata, truncation=True, |
| max_length=self.max_length, padding=True, do_resize=False, return_tensors='pt', |
| **video_kwargs |
| ) |
| return inputs |
|
|
    @staticmethod
    def _pooling_last(hidden_state: torch.Tensor, attention_mask: torch.Tensor) -> torch.Tensor:
        """Select the hidden state of the last non-padded token in each sequence."""
        flipped_tensor = attention_mask.flip(dims=[1])
        last_one_positions = flipped_tensor.argmax(dim=1)
        col = attention_mask.shape[1] - last_one_positions - 1
        row = torch.arange(hidden_state.shape[0], device=hidden_state.device)
        return hidden_state[row, col]
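
    # Example: attention_mask [[1, 1, 1, 0, 0]] flips to [[0, 0, 1, 1, 1]],
    # argmax = 2, so col = 5 - 2 - 1 = 2: the last real token. This relies on
    # right padding, which the processor above is configured for.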
| |
    def _truncate_dimensions(self, embeddings: torch.Tensor) -> torch.Tensor:
        """Optionally truncate the last (embedding) dimension to `self.embed_dim`."""
        if self.embed_dim is not None and embeddings.shape[-1] > self.embed_dim:
            return embeddings[..., :self.embed_dim]
        return embeddings
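
    # Example: with embed_dim=512 and hidden size 1024, token embeddings of shape
    # (B, N, 1024) are sliced to (B, N, 512) before re-normalization (a simple
    # Matryoshka-style dimension cut; the numbers here are illustrative).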
|
|
| |
| def process(self, inputs: List[Dict[str, Any]], normalize: bool = True, pooling: bool = False) -> tuple: |
| conversations = [self.format_model_input( |
| text=ele.get('text'), |
| image=ele.get('image'), |
| instruction=ele.get('instruction'), |
| ) for ele in inputs] |
|
|
| processed_inputs = self._preprocess_inputs(conversations) |
| processed_inputs = {k: v.to(self.model.device) for k, v in processed_inputs.items()} |
|
|
| outputs = self.forward(processed_inputs) |
| |
| embeddings = outputs['embeddings'] |
| attention_mask = outputs['attention_mask'] |
|
|
        if pooling:
            # Single-vector output: last-token pooling.
            embeddings = self._pooling_last(embeddings, attention_mask)
        else:
            # Multi-vector output: optionally truncate the embedding dimension.
            embeddings = self._truncate_dimensions(embeddings)

        if normalize:
            embeddings = F.normalize(embeddings, p=2, dim=-1)

        return embeddings, attention_mask
|
|
| @staticmethod |
| def score_maxsim( |
| query_embeddings: torch.Tensor, |
| doc_embeddings: torch.Tensor, |
| query_mask: torch.Tensor, |
| doc_mask: torch.Tensor, |
| device: str = "cuda" if torch.cuda.is_available() else "cpu" |
| ) -> torch.Tensor: |
| """ |
| Compute MaxSim scores between queries and documents (multi-vector). |
| |
        Args:
            query_embeddings: (Q, Lq, D) multi-vector query embeddings (normalized).
            doc_embeddings: (N, Ld, D) multi-vector doc embeddings (normalized).
            query_mask: (Q, Lq) attention mask for queries.
            doc_mask: (N, Ld) attention mask for docs.

        Returns:
            scores: (Q, N) MaxSim similarity matrix.
        """
        query_embeddings = query_embeddings.to(device)
        doc_embeddings = doc_embeddings.to(device)
        query_mask = query_mask.to(device)
        doc_mask = doc_mask.to(device)

        # Pairwise token similarities: (Q, Lq, N, Ld).
        sim = torch.einsum("qid,njd->qinj", query_embeddings, doc_embeddings)

        # Padded doc tokens must never win the max.
        doc_pad_mask = ~doc_mask.bool()
        sim = sim.masked_fill(doc_pad_mask.unsqueeze(0).unsqueeze(0), float("-inf"))

        # Padded query tokens contribute zero to the sum.
        query_pad_mask = ~query_mask.bool()
        sim = sim.masked_fill(query_pad_mask.unsqueeze(2).unsqueeze(-1), 0.0)

        # MaxSim: best doc token per query token, then sum over query tokens.
        scores = sim.max(dim=-1).values
        scores = scores.sum(dim=1)
| |
| return scores |
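
    # Shape walk-through (illustrative): Q=2 queries of Lq=5 tokens against N=10
    # docs of Ld=7 tokens gives sim of shape (2, 5, 10, 7); max over Ld then sum
    # over Lq yields scores of shape (2, 10).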
|
|
| @staticmethod |
| def score_dense( |
| query_embeddings: torch.Tensor, |
| doc_embeddings: torch.Tensor, |
| device: str = "cuda" if torch.cuda.is_available() else "cpu" |
| ) -> torch.Tensor: |
| """ |
| Compute dot-product scores between pooled query and doc embeddings. |
| |
        Args:
            query_embeddings: (Q, D) pooled and normalized query embeddings.
            doc_embeddings: (N, D) pooled and normalized doc embeddings.

        Returns:
            scores: (Q, N) dot-product similarity matrix.
        """
| doc_embeddings = doc_embeddings.to(device) |
| query_embeddings = query_embeddings.to(device) |
| return torch.matmul(query_embeddings, doc_embeddings.T) |
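

if __name__ == "__main__":
    # Minimal end-to-end sketch. The image path and query text are placeholders;
    # substitute your own inputs.
    embedder = ColQwen3_5Embedder(model_name_or_path="Qwen/Qwen3.5-0.8B")

    # Late-interaction (multi-vector) retrieval scored with MaxSim.
    q_emb, q_mask = embedder.process([{"text": "a diagram of a transformer"}])
    d_emb, d_mask = embedder.process([{"image": "/path/to/page.png"}])
    print(ColQwen3_5Embedder.score_maxsim(q_emb, d_emb, q_mask, d_mask))  # shape (1, 1)

    # Dense (single-vector) retrieval scored with a dot product.
    q_vec, _ = embedder.process([{"text": "a diagram of a transformer"}], pooling=True)
    d_vec, _ = embedder.process([{"image": "/path/to/page.png"}], pooling=True)
    print(ColQwen3_5Embedder.score_dense(q_vec, d_vec))  # shape (1, 1)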
| |