| from typing import Callable, Literal |
| import numpy as np |
| import torch |
| from transformers import Qwen3Model |
| from transformers.cache_utils import Cache |
| from transformers.masking_utils import create_causal_mask |
| from transformers.modeling_outputs import BaseModelOutputWithPooling |
| from transformers.processing_utils import Unpack |
| from transformers.utils import TransformersKwargs |
| from .configuration import PPLXQwen3Config |
| from transformers import AutoTokenizer |
| from .st_quantize import FlexibleQuantizer |
|
|
|
|
| |
| def bidirectional_mask_function(attention_mask: torch.Tensor | None) -> Callable: |
| """ |
| This creates bidirectional attention mask. |
| """ |
|
|
| def inner_mask(batch_idx: int, head_idx: int, q_idx: int, kv_idx: int) -> bool: |
| if attention_mask is None: |
| return torch.ones((), dtype=torch.bool) |
| return attention_mask[batch_idx, kv_idx].to(torch.bool) |
|
|
| return inner_mask |
|
|
|
|
class PPLXQwen3Model(Qwen3Model):
    """
    Qwen3 backbone with bidirectional (non-causal) self-attention.

    Every decoder layer's ``self_attn.is_causal`` flag is cleared in
    ``post_init``. ``forward`` pre-builds the attention mask so that each token
    may attend to every non-padding token of its sequence.
    """

    _supports_flash_attn = True
    _supports_sdpa = True

    config_class = PPLXQwen3Config

    def __init__(self, config):
        super().__init__(config)
        # NOTE(review): Qwen3Model.__init__ presumably already calls
        # self.post_init() (dispatching to the override below). This explicit
        # call guarantees is_causal is cleared regardless.
        self.post_init()

    def post_init(self):
        """Run the standard post-init, then disable causal attention in every layer."""
        super().post_init()

        for layer in self.layers:
            layer.self_attn.is_causal = False

    def forward(
        self,
        input_ids: torch.LongTensor | None = None,
        attention_mask: torch.Tensor | None = None,
        position_ids: torch.LongTensor | None = None,
        past_key_values: Cache | None = None,
        inputs_embeds: torch.FloatTensor | None = None,
        use_cache: bool | None = None,
        cache_position: torch.LongTensor | None = None,
        **kwargs: Unpack[TransformersKwargs],
    ) -> BaseModelOutputWithPooling:
        """
        Run the Qwen3 backbone with a bidirectional attention mask.

        Args:
            input_ids: Token ids. Exactly one of ``input_ids`` / ``inputs_embeds``
                must be given.
            attention_mask: Optional 2-D padding mask (batch, seq_len).
            position_ids: Optional position ids forwarded to the backbone.
            past_key_values: Forwarded to the backbone. The bidirectional mask
                is always built as if there were no cache.
            inputs_embeds: Pre-computed input embeddings.
            use_cache: Whether the backbone should build/return a KV cache.
                Defaults to False here, because caching does not fit a
                bidirectional encoder and the mask ignores any cache.
            cache_position: Forwarded to the backbone.
            **kwargs: Extra keyword arguments forwarded to ``Qwen3Model.forward``.

        Returns:
            The object returned by ``Qwen3Model.forward`` (``last_hidden_state``
            holds the token embeddings). NOTE(review): the annotation says
            BaseModelOutputWithPooling, but the value comes straight from the
            parent forward — confirm the intended type.

        Raises:
            ValueError: If neither or both of ``input_ids`` and ``inputs_embeds``
                are provided.
        """
        if (input_ids is None) == (inputs_embeds is None):
            raise ValueError("You must specify exactly one of input_ids or inputs_embeds")

        if inputs_embeds is None:
            inputs_embeds = self.embed_tokens(input_ids)
            input_ids = None

        # The mask below assumes no cache (past_key_values=None, positions from 0).
        # Do not let the config default silently allocate a KV cache on every call.
        if use_cache is None:
            use_cache = False

        # Positions 0..seq_len-1, used only to shape the mask.
        dummy_cache_position = torch.arange(
            inputs_embeds.shape[1], device=inputs_embeds.device, dtype=torch.long
        )
        # The OR-ed predicate allows every valid key for every query, so the
        # resulting mask is bidirectional. Passing a dict tells the parent
        # forward that the mask is already prepared.
        mask_mapping = {
            "full_attention": create_causal_mask(
                config=self.config,
                input_embeds=inputs_embeds,
                attention_mask=attention_mask,
                cache_position=dummy_cache_position,
                past_key_values=None,
                position_ids=position_ids,
                or_mask_function=bidirectional_mask_function(attention_mask),
            )
        }

        return super().forward(
            input_ids=input_ids,
            attention_mask=mask_mapping,
            position_ids=position_ids,
            past_key_values=past_key_values,
            inputs_embeds=inputs_embeds,
            use_cache=use_cache,
            cache_position=cache_position,
            **kwargs,
        )
|
|
|
|
class PPLXQwen3ContextualModel(PPLXQwen3Model):
    """
    Qwen3 model with contextual encoding support for late chunking.

    This model extends PPLXQwen3Model with an encode() method. The method
    takes documents as nested lists (``list[list[str]]``, one list of text
    chunks per document) and returns one embedding per chunk, computed with
    late chunking. Flat ``list[str]`` input is rejected with a TypeError.

    IMPORTANT: This model MUST be loaded with trust_remote_code=True:

        from transformers import AutoModel

        model = AutoModel.from_pretrained(
            "path/to/model",
            trust_remote_code=True  # REQUIRED!
        )

        embeddings = model.encode([["chunk1", "chunk2"]])

    Loading without trust_remote_code=True will fail to load this custom model class.
    """

    config_class = PPLXQwen3Config

    def __init__(self, config):
        # Check the config type before building any weights, so a wrong config
        # fails fast with an actionable message.
        if not isinstance(config, PPLXQwen3Config):
            raise TypeError(
                f"PPLXQwen3ContextualModel requires PPLXQwen3Config, got {type(config).__name__}. "
                f"Did you forget to load with trust_remote_code=True?"
            )

        super().__init__(config)

        # NOTE(review): this loads a tokenizer from config._name_or_path at
        # construction time, so that path must be set and reachable — confirm
        # this holds for every loading path.
        self.tokenizer = AutoTokenizer.from_pretrained(config._name_or_path)
        self._flexible_quantizer = FlexibleQuantizer()

    @staticmethod
    def mean_pooling(
        token_embeddings: torch.Tensor, attention_mask: torch.Tensor
    ) -> torch.Tensor:
        """
        Average token embeddings over positions where ``attention_mask`` is set.

        Args:
            token_embeddings: (batch, seq_len, hidden_dim) tensor.
            attention_mask: (batch, seq_len) mask; zero entries are excluded.

        Returns:
            (batch, hidden_dim) float32 tensor. Rows with no selected positions
            (including seq_len == 0) yield zeros, thanks to the clamp on the
            denominator.
        """
        input_mask_expanded = (
            attention_mask.unsqueeze(-1).expand(token_embeddings.size()).float()
        )
        return torch.sum(token_embeddings * input_mask_expanded, 1) / torch.clamp(
            input_mask_expanded.sum(1), min=1e-9
        )

    @torch.inference_mode()
    def encode(
        self,
        documents: list[list[str]],
        batch_size: int = 32,
        show_progress_bar: bool = False,
        device: str | torch.device | None = None,
        normalize_embeddings: bool = False,
        convert_to_numpy: bool = True,
        quantization: Literal["int8", "binary"] = "int8",
    ) -> list[np.ndarray] | list[torch.Tensor]:
        """
        Encode documents with late chunking (contextual embeddings).

        Documents are always nested lists, where each document is a list of
        text chunks.

        The encoding process:
        1. Concatenate chunks with the tokenizer's separator token
        2. Run a forward pass to get token embeddings
        3. Extract and mean-pool individual chunk embeddings (late chunking)
        4. Apply quantization (int8 or binary, always enabled)
        5. Normalize embeddings if requested (applied after quantization)
        6. Convert to numpy or return as tensors

        Note: the concatenated document is tokenized with ``truncation=True``.
        If it exceeds the tokenizer's maximum length, trailing chunks are cut
        off, and that document gets fewer rows than it has chunks.

        Args:
            documents: List of documents, where each document is a list of text chunks.
                Example: [["chunk1", "chunk2"], ["chunk1", "chunk2", "chunk3"]]
            batch_size: Number of documents per forward pass (must be >= 1).
            show_progress_bar: Show a tqdm progress bar if tqdm is installed.
            device: Device the inputs are moved to (defaults to the model's device).
                The model itself is not moved.
            normalize_embeddings: L2-normalize embeddings (applied after
                quantization; non-floating outputs are cast to float32 first).
            convert_to_numpy: If True, returns list[np.ndarray], otherwise list[torch.Tensor].
            quantization: Quantization type to apply. Options:
                - "int8": Int8 tanh quantization (default)
                - "binary": Binary tanh quantization

        Returns:
            List of numpy arrays or tensors (preserves document structure).
            Each element has shape (n_chunks, hidden_dim). For example, with
            the input above and hidden_dim=1024:
            embeddings[0].shape = (2, 1024), embeddings[1].shape = (3, 1024).
            Output values depend on the quantization method:
                - Int8: int8 values in range [-128, 127]
                - Binary: float values -1.0 or 1.0

        Raises:
            TypeError: If ``documents`` is not a list of lists.
            ValueError: If ``quantization`` is unsupported, ``batch_size`` < 1,
                or the tokenizer has no separator token.
        """
        if not isinstance(documents, list) or not all(
            isinstance(doc, list) for doc in documents
        ):
            raise TypeError(
                "Input 'documents' must be a list of lists of strings for contextual encoding."
            )

        if quantization not in ("int8", "binary"):
            raise ValueError(
                f"Unsupported quantization type: '{quantization}'. "
                f"Supported types are: 'int8', 'binary'. "
                f"Got: {type(quantization).__name__} = '{quantization}'"
            )

        if batch_size < 1:
            raise ValueError(f"batch_size must be a positive integer, got {batch_size}")

        sep_token = self.tokenizer.sep_token
        if sep_token is None and documents:
            raise ValueError(
                "The tokenizer has no sep_token; it is required to join chunks "
                "for contextual encoding."
            )

        self.eval()

        if device is None:
            device = next(self.parameters()).device

        all_embeddings: list[torch.Tensor] = []

        batch_starts = range(0, len(documents), batch_size)
        if show_progress_bar:
            try:
                from tqdm import tqdm
            except ImportError:
                pass  # tqdm is optional; encode without a progress bar.
            else:
                batch_starts = tqdm(batch_starts, desc="Encoding documents")

        for start in batch_starts:
            batch_docs = documents[start : start + batch_size]
            doc_strings = [sep_token.join(chunks) for chunks in batch_docs]

            inputs = self.tokenizer(
                doc_strings,
                padding=True,
                truncation=True,
                return_tensors="pt",
            )
            # Forward only what the backbone needs, so extra tokenizer outputs
            # (e.g. token_type_ids) never leak into the attention kwargs.
            input_ids = inputs["input_ids"].to(device)
            attention_mask = inputs["attention_mask"].to(device)

            outputs = self.forward(input_ids=input_ids, attention_mask=attention_mask)

            batch_chunk_embeddings = self._extract_chunks_from_concatenated(
                input_ids=input_ids,
                token_embeddings=outputs.last_hidden_state,
                attention_mask=attention_mask,
            )

            for doc_chunks in batch_chunk_embeddings:
                emb = torch.stack(doc_chunks, dim=0)
                emb = self._flexible_quantizer(
                    {"sentence_embedding": emb}, quantization=quantization
                )["sentence_embedding"]

                if normalize_embeddings:
                    # F.normalize needs a floating dtype (it fails on integer
                    # tensors), and normalized values are no longer integral anyway.
                    if not emb.is_floating_point():
                        emb = emb.float()
                    emb = torch.nn.functional.normalize(emb, p=2, dim=-1)

                all_embeddings.append(emb.cpu())

        if convert_to_numpy:
            # numpy has no bfloat16; upcast only in that case.
            all_embeddings = [
                (emb.float() if emb.dtype == torch.bfloat16 else emb).numpy()
                for emb in all_embeddings
            ]

        return all_embeddings

    def _extract_chunks_from_concatenated(
        self,
        input_ids: torch.Tensor,
        token_embeddings: torch.Tensor,
        attention_mask: torch.Tensor,
    ) -> list[list[torch.Tensor]]:
        """
        Extract individual chunk embeddings from concatenated sequences using late chunking.

        This method splits concatenated sequences like "[chunk1][SEP][chunk2][SEP]..."
        back into per-chunk embeddings at SEP token positions. Each segment is
        mean-pooled over its valid (non-padding) tokens.

        Segments run from the first valid token to each SEP, and from the last
        SEP to the last valid token. Deriving both bounds from the attention
        mask works for left and right padding alike. Only SEP tokens at valid
        positions count, because the padding token may share the SEP id.

        Args:
            input_ids: Token IDs (batch_size, seq_len)
            token_embeddings: Token embeddings (batch_size, seq_len, hidden_dim)
            attention_mask: Attention mask (batch_size, seq_len)

        Returns:
            list[list[torch.Tensor]]: One list per document, with one
            (hidden_dim,) embedding per segment. A document with k valid SEP
            tokens yields k + 1 segments. Empty segments yield zero vectors.

        Note:
            The separator id is read from self.tokenizer.sep_token_id and varies by tokenizer.
        """
        sep_token_id = self.tokenizer.sep_token_id

        all_doc_chunks = []

        for row_ids, row_embs, row_mask in zip(input_ids, token_embeddings, attention_mask):
            valid_positions = row_mask.bool()
            valid_idx = valid_positions.nonzero(as_tuple=True)[0]
            if valid_idx.numel() == 0:
                # Fully padded row: produce a single empty (zero) segment.
                first_valid, end_valid = 0, 0
            else:
                first_valid = int(valid_idx[0])
                end_valid = int(valid_idx[-1]) + 1

            sep_positions = (
                ((row_ids == sep_token_id) & valid_positions)
                .nonzero(as_tuple=True)[0]
                .tolist()
            )

            # Segment k spans [starts[k], ends[k]); SEP tokens themselves are excluded.
            starts = [first_valid] + [pos + 1 for pos in sep_positions]
            ends = sep_positions + [end_valid]

            chunk_embeddings = [
                self.mean_pooling(
                    row_embs[seg_start:seg_end].unsqueeze(0),
                    row_mask[seg_start:seg_end].unsqueeze(0),
                ).squeeze(0)
                for seg_start, seg_end in zip(starts, ends)
            ]

            all_doc_chunks.append(chunk_embeddings)

        return all_doc_chunks
|
|