Text Ranking
Transformers
Safetensors
multilingual
t5gemma2
text2text-generation
reranker
encoder-decoder
FBNL
Retrieval
RAG
Instructions to use KaLM-Embedding/KaLM-Reranker-V1-Small with libraries, inference providers, notebooks, and local apps. Follow these links to get started.
- Libraries
- Transformers
How to use KaLM-Embedding/KaLM-Reranker-V1-Small with Transformers:
```python
# Load model directly
from transformers import AutoProcessor, AutoModelForMultimodalLM

processor = AutoProcessor.from_pretrained("KaLM-Embedding/KaLM-Reranker-V1-Small")
model = AutoModelForMultimodalLM.from_pretrained("KaLM-Embedding/KaLM-Reranker-V1-Small")
```
- Notebooks
- Google Colab
- Kaggle
| from __future__ import annotations | |
| from typing import Any, Dict, List, Optional, Sequence, Tuple, Union | |
| import numpy as np | |
| import torch | |
| import torch.nn.functional as F | |
| from transformers import AutoModelForSeq2SeqLM, AutoTokenizer | |
| from transformers.modeling_outputs import BaseModelOutput | |
# Task instruction written after "<Instruct>:" in the decoder prompt when the
# caller does not provide one of their own.
DEFAULT_INSTRUCTION = "Given a query, retrieve documents that answer the query."

# System text placed at the top of the decoder prompt. It restricts the model's
# answer to "yes" or "no", whose logits are turned into the relevance score.
DEFAULT_SYSTEM_INSTRUCTION = (
    "Judge whether the Document meets the requirements based on the Query and the "
    "Instruct provided. Note that the answer can only be \"yes\" or \"no\"."
)
class KaLMReranker:
    """Score query-document relevance with a KaLM encoder-decoder reranker.

    The document is fed to the encoder (prefixed with ``"<Document>: "``).
    A chat-style prompt goes to the decoder; it holds the system
    instruction, the task instruction and the (truncated) query. The
    returned score is ``P(yes)``: a two-class softmax over the decoder's
    ``yes`` and ``no`` logits, taken at the last non-padding decoder
    position.

    When ``chunk_size`` is not ``None``, the encoder hidden states are
    mean-pooled over consecutive windows of ``chunk_size`` tokens. The
    decoder then cross-attends to the shorter pooled sequence.
    """

    # Message of the RuntimeError raised when scoring runs out of CUDA memory
    # even at the batch size chosen by the probe in ``predict``.
    _OOM_MESSAGE = (
        "CUDA ran out of memory during reranking. Retry with a smaller batch_size "
        "or shorter max_length."
    )

    def __init__(
        self,
        model_name_or_path: str,
        *,
        device: Optional[Union[str, torch.device]] = None,
        dtype: Optional[Union[str, torch.dtype]] = None,
        batch_size: int = 32,
        query_max_length: int = 512,
        max_length: int = 1024,
        chunk_size: Optional[int] = 4,
        instruction: str = DEFAULT_INSTRUCTION,
        system_instruction: str = DEFAULT_SYSTEM_INSTRUCTION,
        **model_kwargs: Any,
    ) -> None:
        """Validate the configuration, then load the tokenizer and model.

        Args:
            model_name_or_path: Hugging Face model id or local directory.
            device: Target device. Defaults to ``"cuda"`` when available,
                otherwise ``"cpu"``.
            dtype: Model dtype, either a ``torch.dtype`` or a name such as
                ``"bfloat16"``, ``"fp16"`` or ``"float32"``. Defaults to
                bfloat16 on CUDA and float32 elsewhere.
            batch_size: Default number of pairs scored per forward pass.
            query_max_length: Maximum number of query tokens kept in the
                decoder prompt.
            max_length: Maximum number of tokens of ``"<Document>: " + text``
                fed to the encoder.
            chunk_size: Window size for mean-pooling encoder states, or
                ``None`` to feed the full encoder sequence to the decoder.
            instruction: Default task instruction inserted into the prompt.
            system_instruction: System text placed at the top of the prompt.
            **model_kwargs: Extra keyword arguments forwarded to
                ``AutoModelForSeq2SeqLM.from_pretrained``.

        Raises:
            ValueError: Empty model name, non-positive sizes, unsupported
                dtype string, or a tokenizer with neither a pad nor an EOS
                token.
            TypeError: Non-string instructions or a dtype of the wrong type.
            RuntimeError: CUDA requested but not available.
        """
        if not isinstance(model_name_or_path, str) or not model_name_or_path:
            raise ValueError("model_name_or_path must be a non-empty string.")
        if batch_size <= 0:
            raise ValueError("batch_size must be positive.")
        if query_max_length <= 0 or max_length <= 0:
            raise ValueError("query_max_length and max_length must be positive.")
        if chunk_size is not None and chunk_size <= 0:
            raise ValueError("chunk_size must be positive or None.")
        if not isinstance(instruction, str) or not isinstance(system_instruction, str):
            raise TypeError("instruction and system_instruction must be strings.")

        self.device = self._resolve_device(device)
        self.dtype = self._resolve_dtype(dtype, self.device)
        self.batch_size = batch_size
        self.query_max_length = query_max_length
        self.max_length = max_length
        self.chunk_size = chunk_size
        self.instruction = instruction
        self.system_instruction = system_instruction

        self.tokenizer = AutoTokenizer.from_pretrained(model_name_or_path)
        if self.tokenizer.pad_token_id is None:
            if self.tokenizer.eos_token_id is None:
                raise ValueError("The tokenizer must define a pad token or an EOS token.")
            self.tokenizer.pad_token = self.tokenizer.eos_token
        # Always pad on the right: _predict_batch finds the last real decoder
        # token at index attention_mask.sum() - 1, which needs right padding.
        self.tokenizer.padding_side = "right"

        self.model = AutoModelForSeq2SeqLM.from_pretrained(
            model_name_or_path,
            dtype=self.dtype,
            **model_kwargs,
        )
        # Cast every floating-point parameter to the requested dtype, in case
        # from_pretrained left some of them in a different precision.
        for parameter in self.model.parameters():
            if parameter.is_floating_point() and parameter.dtype != self.dtype:
                parameter.data = parameter.data.to(dtype=self.dtype)
        self.model.to(device=self.device)
        self.model.eval()

        self.yes_token_id = self._answer_token_id("yes")
        self.no_token_id = self._answer_token_id("no")

    @staticmethod
    def _resolve_device(device: Optional[Union[str, torch.device]]) -> torch.device:
        """Return the torch device to run on, defaulting to CUDA when present.

        Raises:
            RuntimeError: If a CUDA device is requested but CUDA is unavailable.
        """
        if device is None:
            device = "cuda" if torch.cuda.is_available() else "cpu"
        resolved = torch.device(device)
        if resolved.type == "cuda" and not torch.cuda.is_available():
            raise RuntimeError("CUDA was requested, but no CUDA device is available.")
        return resolved

    @staticmethod
    def _resolve_dtype(
        dtype: Optional[Union[str, torch.dtype]], device: torch.device
    ) -> torch.dtype:
        """Map a user-supplied dtype (object, name, or None) to a torch.dtype.

        ``None`` selects bfloat16 on CUDA and float32 otherwise. Names are
        case-insensitive and may carry a ``"torch."`` prefix.

        Raises:
            TypeError: If ``dtype`` is neither a torch.dtype, a string nor None.
            ValueError: If the name is not one of the supported aliases.
        """
        if dtype is None:
            return torch.bfloat16 if device.type == "cuda" else torch.float32
        if isinstance(dtype, torch.dtype):
            return dtype
        if not isinstance(dtype, str):
            raise TypeError("dtype must be a torch.dtype or a string such as 'bfloat16'.")
        normalized = dtype.lower().removeprefix("torch.")
        supported = {
            "bfloat16": torch.bfloat16,
            "bf16": torch.bfloat16,
            "float16": torch.float16,
            "fp16": torch.float16,
            "float32": torch.float32,
            "fp32": torch.float32,
        }
        if normalized not in supported:
            raise ValueError(f"Unsupported dtype: {dtype!r}.")
        return supported[normalized]

    def _answer_token_id(self, answer: str) -> int:
        """Return the vocabulary id whose logit represents ``answer``.

        NOTE: assumes ``answer`` encodes to a single token. If the tokenizer
        splits it, the last piece is used.

        Raises:
            ValueError: If ``answer`` encodes to no tokens at all.
        """
        token_ids = self.tokenizer(answer, add_special_tokens=False)["input_ids"]
        if not token_ids:
            raise ValueError(f"Failed to tokenize the answer {answer!r}.")
        return token_ids[-1]

    def _get_encoder(self):
        """Return the model's encoder module (via ``get_encoder`` or ``.encoder``)."""
        if hasattr(self.model, "get_encoder"):
            return self.model.get_encoder()
        if hasattr(self.model, "encoder"):
            return self.model.encoder
        raise AttributeError(f"Cannot find the encoder on {type(self.model).__name__}.")

    @staticmethod
    def _pool_encoder_chunks(
        hidden_states: torch.Tensor,
        attention_mask: torch.Tensor,
        chunk_size: int,
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        """Mean-pool encoder states over non-overlapping windows of ``chunk_size``.

        Args:
            hidden_states: ``(batch, seq, hidden)`` encoder output.
            attention_mask: ``(batch, seq)`` mask, 1 for real tokens.
            chunk_size: Number of consecutive positions merged per window.

        Returns:
            ``(pooled_hidden, pooled_mask)`` with shapes
            ``(batch, ceil(seq / chunk_size), hidden)`` and
            ``(batch, ceil(seq / chunk_size))``. Each window is the mean of
            its real (unmasked) tokens. A window is masked out when it holds
            only padding.
        """
        batch_size, sequence_length, hidden_size = hidden_states.shape
        num_chunks = (sequence_length + chunk_size - 1) // chunk_size
        pad_length = num_chunks * chunk_size - sequence_length
        if pad_length:
            # Pad the sequence axis so it divides evenly into windows; the
            # padded positions have mask 0 and do not affect the means.
            hidden_states = F.pad(hidden_states, (0, 0, 0, pad_length))
            attention_mask = F.pad(attention_mask, (0, pad_length))
        hidden_states = hidden_states.view(
            batch_size, num_chunks, chunk_size, hidden_size
        )
        chunk_mask = attention_mask.view(batch_size, num_chunks, chunk_size)
        expanded_mask = chunk_mask.unsqueeze(-1).to(hidden_states.dtype)
        pooled_hidden = (hidden_states * expanded_mask).sum(dim=2)
        # clamp(min=1) avoids division by zero for all-padding windows.
        pooled_hidden = pooled_hidden / chunk_mask.sum(dim=2).clamp(min=1).unsqueeze(-1)
        pooled_mask = (chunk_mask.sum(dim=2) > 0).to(attention_mask.dtype)
        return pooled_hidden, pooled_mask

    def _decoder_text(self, query: str, instruction: str) -> str:
        """Build the decoder prompt for ``query`` under ``instruction``.

        The query is truncated to ``query_max_length`` tokens by
        round-tripping it through the tokenizer. Special tokens such as
        ``<bos>`` are written literally into the text, because the prompt is
        later tokenized with ``add_special_tokens=False``.
        """
        query_ids = self.tokenizer(
            query,
            add_special_tokens=False,
            truncation=True,
            max_length=self.query_max_length,
        )["input_ids"]
        truncated_query = self.tokenizer.decode(
            query_ids,
            skip_special_tokens=False,
            clean_up_tokenization_spaces=False,
        )
        return (
            "<bos><start_of_turn>user\n"
            f"{self.system_instruction}\n\n"
            f"<Instruct>: {instruction}\n"
            f"<Query>: {truncated_query}<end_of_turn>\n"
            "<start_of_turn>model\n\n\n\n"
        )

    @staticmethod
    def _validate_pairs(
        pairs: Sequence[Tuple[str, str]],
    ) -> List[Tuple[str, str]]:
        """Check ``pairs`` is a sequence of ``(query, document)`` string pairs.

        Returns:
            The pairs as a new list of tuples.

        Raises:
            TypeError: If ``pairs`` is not a sequence, or an element is not a str.
            ValueError: If an element is not a length-2 sequence.
        """
        if isinstance(pairs, (str, bytes)) or not isinstance(pairs, Sequence):
            raise TypeError("pairs must be a sequence of (query, document) pairs.")
        validated: List[Tuple[str, str]] = []
        for index, pair in enumerate(pairs):
            if (
                isinstance(pair, (str, bytes))
                or not isinstance(pair, Sequence)
                or len(pair) != 2
            ):
                raise ValueError(f"pairs[{index}] must contain exactly two strings.")
            query, document = pair
            if not isinstance(query, str) or not isinstance(document, str):
                raise TypeError(f"pairs[{index}] must contain exactly two strings.")
            validated.append((query, document))
        return validated

    @torch.inference_mode()
    def _predict_batch(
        self, pairs: Sequence[Tuple[str, str]], instruction: str
    ) -> List[float]:
        """Score one batch of pairs and return ``P(yes)`` for each, in order.

        Runs under ``torch.inference_mode`` so the forward pass records no
        autograd graph. ``eval()`` alone does not disable gradient tracking.

        Raises:
            RuntimeError: If any yes/no logit is NaN or infinite.
        """
        encoder_texts = [f"<Document>: {document}" for _, document in pairs]
        decoder_texts = [self._decoder_text(query, instruction) for query, _ in pairs]
        encoder_batch = self.tokenizer(
            encoder_texts,
            padding=True,
            truncation=True,
            max_length=self.max_length,
            add_special_tokens=False,
            return_tensors="pt",
        ).to(self.device)
        decoder_batch = self.tokenizer(
            decoder_texts,
            padding=True,
            pad_to_multiple_of=8,
            add_special_tokens=False,
            return_tensors="pt",
        ).to(self.device)

        if self.chunk_size is None:
            outputs = self.model(
                input_ids=encoder_batch["input_ids"],
                attention_mask=encoder_batch["attention_mask"],
                decoder_input_ids=decoder_batch["input_ids"],
                decoder_attention_mask=decoder_batch["attention_mask"],
                return_dict=True,
            )
        else:
            # Run the encoder on its own so its states can be pooled into
            # windows before the decoder cross-attends to them.
            encoder_outputs = self._get_encoder()(
                input_ids=encoder_batch["input_ids"],
                attention_mask=encoder_batch["attention_mask"],
                return_dict=True,
            )
            pooled_hidden, pooled_mask = self._pool_encoder_chunks(
                encoder_outputs.last_hidden_state,
                encoder_batch["attention_mask"],
                self.chunk_size,
            )
            outputs = self.model(
                encoder_outputs=BaseModelOutput(last_hidden_state=pooled_hidden),
                attention_mask=pooled_mask,
                decoder_input_ids=decoder_batch["input_ids"],
                decoder_attention_mask=decoder_batch["attention_mask"],
                return_dict=True,
            )

        logits = outputs.logits
        # With right padding, the last real token of row i sits at
        # attention_mask[i].sum() - 1. Build the indices on the logits'
        # device, which is not necessarily self.device.
        last_positions = decoder_batch["attention_mask"].sum(dim=1).to(logits.device) - 1
        row_indices = torch.arange(logits.shape[0], device=logits.device)
        last_logits = logits[row_indices, last_positions]
        yes_no_logits = torch.stack(
            (
                last_logits[:, self.yes_token_id],
                last_logits[:, self.no_token_id],
            ),
            dim=-1,
        ).float()
        if not torch.isfinite(yes_no_logits).all():
            bad_count = (~torch.isfinite(yes_no_logits).all(dim=-1)).sum().item()
            raise RuntimeError(
                f"The model produced non-finite yes/no logits for {bad_count} input(s). "
                "Use bfloat16 or float32 instead of float16."
            )
        return torch.softmax(yes_no_logits, dim=-1)[:, 0].cpu().tolist()

    def predict(
        self,
        pairs: Sequence[Tuple[str, str]],
        *,
        instruction: Optional[str] = None,
        batch_size: Optional[int] = None,
    ) -> List[float]:
        """Return ``P(yes)`` scores in the same order as ``pairs``.

        Pairs are scored longest-first (by character count). This keeps
        padding small within each batch and makes the first batch the most
        memory-hungry one. That first batch probes the batch size: on CUDA
        out-of-memory it is retried at 3/4 of the size until it fits. Its
        scores are kept, and the remaining pairs are scored with the size
        that worked.

        Args:
            pairs: ``(query, document)`` string pairs.
            instruction: Task instruction for this call. Defaults to the one
                given at construction.
            batch_size: Starting batch size for this call. Defaults to the
                one given at construction.

        Returns:
            One float in ``[0, 1]`` per pair. Returns an empty list for empty
            input.

        Raises:
            TypeError: Malformed ``pairs`` or non-string ``instruction``.
            ValueError: A pair without exactly two items, or a non-positive
                or non-integer ``batch_size``.
            RuntimeError: CUDA ran out of memory even at batch size 1 or at
                the probed size, or the model produced non-finite logits.
        """
        validated_pairs = self._validate_pairs(pairs)
        if not validated_pairs:
            return []
        effective_instruction = self.instruction if instruction is None else instruction
        if not isinstance(effective_instruction, str):
            raise TypeError("instruction must be a string or None.")
        effective_batch_size = self.batch_size if batch_size is None else batch_size
        # bool is a subclass of int; reject True/False explicitly.
        if (
            isinstance(effective_batch_size, bool)
            or not isinstance(effective_batch_size, int)
            or effective_batch_size <= 0
        ):
            raise ValueError("batch_size must be a positive integer.")

        length_sorted_indices = np.argsort(
            [-(len(query) + len(document)) for query, document in validated_pairs]
        )
        sorted_pairs = [validated_pairs[index] for index in length_sorted_indices]

        # Probe with the first (longest) batch, shrinking on CUDA OOM. The
        # probe's scores are kept so this batch is not computed twice.
        current_batch_size = effective_batch_size
        while True:
            try:
                sorted_scores: List[float] = list(
                    self._predict_batch(
                        sorted_pairs[:current_batch_size], effective_instruction
                    )
                )
                break
            except torch.cuda.OutOfMemoryError as error:
                if torch.cuda.is_available():
                    torch.cuda.empty_cache()
                if current_batch_size == 1:
                    raise RuntimeError(self._OOM_MESSAGE) from error
                current_batch_size = max(1, current_batch_size * 3 // 4)

        try:
            for start in range(len(sorted_scores), len(sorted_pairs), current_batch_size):
                sorted_scores.extend(
                    self._predict_batch(
                        sorted_pairs[start : start + current_batch_size],
                        effective_instruction,
                    )
                )
        except torch.cuda.OutOfMemoryError as error:
            if torch.cuda.is_available():
                torch.cuda.empty_cache()
            raise RuntimeError(self._OOM_MESSAGE) from error

        # Undo the length sort so scores line up with the caller's pairs.
        inverse_indices = np.argsort(length_sorted_indices)
        return [sorted_scores[index] for index in inverse_indices]

    def rank(
        self,
        query: str,
        documents: Sequence[str],
        *,
        instruction: Optional[str] = None,
        top_k: Optional[int] = None,
        batch_size: Optional[int] = None,
    ) -> List[Dict[str, Union[int, float]]]:
        """Rank documents and return ``corpus_id``/``score`` dictionaries.

        ``corpus_id`` is the document's index in ``documents``. Results are
        sorted by descending score; ties keep their input order because the
        sort is stable. ``top_k`` truncates the list (``None`` keeps all).

        Raises:
            TypeError: Non-string query or documents.
            ValueError: ``top_k`` that is not a non-negative integer or None.
        """
        if not isinstance(query, str):
            raise TypeError("query must be a string.")
        if isinstance(documents, (str, bytes)) or not isinstance(documents, Sequence):
            raise TypeError("documents must be a sequence of strings.")
        if any(not isinstance(document, str) for document in documents):
            raise TypeError("every document must be a string.")
        if top_k is not None and (
            isinstance(top_k, bool) or not isinstance(top_k, int) or top_k < 0
        ):
            raise ValueError("top_k must be a non-negative integer or None.")
        scores = self.predict(
            [(query, document) for document in documents],
            instruction=instruction,
            batch_size=batch_size,
        )
        rankings: List[Dict[str, Union[int, float]]] = [
            {"corpus_id": corpus_id, "score": score}
            for corpus_id, score in enumerate(scores)
        ]
        rankings.sort(key=lambda item: item["score"], reverse=True)
        return rankings if top_k is None else rankings[:top_k]
# Public API of this module: star-imports expose only the reranker class.
__all__ = [
    "KaLMReranker",
]