---
base_model:
- Qwen/Qwen3-4B-Instruct-2507
datasets:
- hotpotqa/hotpot_qa
- dgslibisey/MuSiQue
- Aman279/Locomo
- Phospheneser/DetectiveQA
language:
- en
- zh
license: apache-2.0
metrics:
- accuracy
- exact_match
- f1
- recall
pipeline_tag: text-ranking
library_name: transformers
tags:
- Rerank
- Memory
---

# QRRanker: Query-focused and Memory-aware Reranker for Long Context Processing

<p align="center">
  <a href="https://qdcassie-li.github.io/QRRanker/"><b>🌐 Project Page</b></a> |
  <a href="https://arxiv.org/abs/2602.12192"><b>📄 Paper</b></a> |
  <a href="https://huggingface.co/MindscapeRAG/QRRanker"><b>🤗 Models</b></a>
</p>

QRRanker is a lightweight reranking framework that leverages **Query-focused Retrieval (QR) heads** to produce continuous relevance scores, enabling effective listwise reranking with small-scale models. It was introduced in the paper [Query-focused and Memory-aware Reranker for Long Context Processing](https://huggingface.co/papers/2602.12192).

## Model Description

Building on existing analyses of retrieval heads in large language models, QRRanker trains a model to estimate passage–query relevance from the attention scores of selected **Query-focused Retrieval (QR) heads**. These heads are identified by computing QR scores on seed data and are particularly effective at capturing query-document relevance signals.

The approach is a **listwise solution**: each candidate is scored using the holistic information within the entire candidate shortlist. It naturally produces **continuous relevance scores**, so it can be trained on arbitrary retrieval datasets without requiring Likert-scale supervision.

### Key Features

- **Listwise Reranking**: Leverages holistic information within the entire candidate shortlist during ranking
- **Continuous Relevance Scores**: Enables training on arbitrary retrieval datasets without requiring Likert-scale supervision
- **Selective Head Usage**: Scores passages using only the top-performing QR attention heads
- **Memory Enhancement**: Optional contextual summaries for improved accuracy on long narratives and dialogues

## Quick Start

### Basic Usage

```python
import torch
from transformers import AutoModel, AutoConfig, AutoTokenizer

# Load model
config = AutoConfig.from_pretrained("MindscapeRAG/QRRanker", trust_remote_code=True)
model = AutoModel.from_pretrained(
    "MindscapeRAG/QRRanker",
    config=config,
    torch_dtype=torch.float16,
    trust_remote_code=True,
)
model.eval()

# Load tokenizer
tokenizer = AutoTokenizer.from_pretrained("MindscapeRAG/QRRanker", trust_remote_code=True)
```

## Input Data Format

Input data should be in JSON format. Each sample contains the following fields:

```json
{
  "id": "sample_001",
  "question": "What is the capital of France?",
  "answer": "Paris",
  "paragraphs": [
    {
      "idx": 0,
      "title": "France",
      "paragraph_text": "Paris is the capital and largest city of France...",
      "is_supporting": true
    },
    {
      "idx": 1,
      "title": "Germany",
      "paragraph_text": "Berlin is the capital of Germany...",
      "is_supporting": false
    }
  ],
  "summary": "Optional summary text..."
}
```

### Field Description

| Field | Type | Required | Description |
|-------|------|----------|-------------|
| `id` | string | Yes | Unique sample identifier |
| `question` | string | Yes | User query/question |
| `answer` | string | No | Ground truth answer (for evaluation) |
| `paragraphs` | list | Yes | List of candidate paragraphs |
| `paragraphs[].idx` | int | Yes | Paragraph index |
| `paragraphs[].title` | string | No | Paragraph title |
| `paragraphs[].paragraph_text` | string | Yes | Paragraph content |
| `paragraphs[].is_supporting` | bool | No | Whether it is a supporting paragraph (for evaluation) |
| `summary` | string | No | Optional summary information |
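
As a rough illustration of the schema above, the following sketch checks that a sample carries the required fields before it is handed to the reranker. The helper name `validate_sample` and its messages are illustrative assumptions and are not part of the QRRanker codebase; it only mirrors the "Required" column of the table.

```python
import json

REQUIRED_SAMPLE_FIELDS = ("id", "question", "paragraphs")
REQUIRED_PARAGRAPH_FIELDS = ("idx", "paragraph_text")


def validate_sample(sample: dict) -> list:
    """Return a list of problems found in one sample (empty if it looks valid).

    Hypothetical helper: it checks presence of required fields only, not types.
    """
    problems = []
    for field in REQUIRED_SAMPLE_FIELDS:
        if field not in sample:
            problems.append(f"missing field: {field}")
    paragraphs = sample.get("paragraphs", [])
    if not isinstance(paragraphs, list):
        problems.append("paragraphs must be a list")
        return problems
    for pos, paragraph in enumerate(paragraphs):
        for field in REQUIRED_PARAGRAPH_FIELDS:
            if field not in paragraph:
                problems.append(f"paragraphs[{pos}] missing field: {field}")
    return problems


sample = json.loads("""
{
  "id": "sample_001",
  "question": "What is the capital of France?",
  "paragraphs": [
    {"idx": 0, "title": "France",
     "paragraph_text": "Paris is the capital and largest city of France..."}
  ]
}
""")
print(validate_sample(sample))  # [] because all required fields are present
```

Optional fields such as `answer`, `is_supporting`, and `summary` are deliberately not checked here.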

## Core Algorithm

### 0. DynamicCacheWithQuery (Custom Cache Class)

This custom cache class is essential for QRRanker. It extends the standard `DynamicCache` so that, alongside the key/value states, it also stores the query states passed in through `cache_kwargs`. The `query_indices` argument records the token positions whose query states are needed.

```python
from typing import Any, Dict, List, Optional, Tuple

import torch
from transformers.cache_utils import DynamicCache


class DynamicCacheWithQuery(DynamicCache):
    """
    Custom cache class for QRRanker that stores both key/value states and query states.
    The query states are extracted at specified token positions for attention computation.

    Assumes a `DynamicCache` that exposes the list-based `key_cache` / `value_cache`
    attributes and the `_seen_tokens` counter.
    """

    def __init__(self, query_indices: Optional[List[int]] = None) -> None:
        super().__init__()
        # Token indices where query states should be saved.
        # None (not []) is the default so instances never share one mutable list.
        self._query_indices = [] if query_indices is None else query_indices
        self.query_cache = []

    def update(
        self,
        key_states: torch.Tensor,
        value_states: torch.Tensor,
        layer_idx: int,
        cache_kwargs: Optional[Dict[str, Any]] = None,
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        """
        Updates the cache with new key_states, value_states, and optionally query_states.

        Parameters:
            key_states: New key states to cache [batch, num_kv_heads, seq_len, head_dim]
            value_states: New value states to cache [batch, num_kv_heads, seq_len, head_dim]
            layer_idx: Index of the layer
            cache_kwargs: Optional dict containing 'query_states' to cache

        Returns:
            Tuple of (updated_key_states, updated_value_states)
        """
        # Update key/value cache
        if key_states is not None:
            # Update seen tokens count (once per forward pass, on the first layer)
            if layer_idx == 0:
                self._seen_tokens += key_states.shape[-2]

            if len(self.key_cache) <= layer_idx:
                # Pad skipped layers with empty tensors, then add this layer
                for _ in range(len(self.key_cache), layer_idx):
                    self.key_cache.append(torch.tensor([]))
                    self.value_cache.append(torch.tensor([]))
                self.key_cache.append(key_states)
                self.value_cache.append(value_states)
            elif not self.key_cache[layer_idx].numel():
                # Layer slot exists but is still empty
                self.key_cache[layer_idx] = key_states
                self.value_cache[layer_idx] = value_states
            else:
                # Append along the sequence dimension
                self.key_cache[layer_idx] = torch.cat(
                    [self.key_cache[layer_idx], key_states], dim=-2
                )
                self.value_cache[layer_idx] = torch.cat(
                    [self.value_cache[layer_idx], value_states], dim=-2
                )

        # Update query cache if query_states provided
        if cache_kwargs is not None:
            query_states = cache_kwargs.get("query_states", None)
        else:
            query_states = None

        if query_states is not None:
            if len(self.query_cache) <= layer_idx:
                self.query_cache.append(query_states)
            else:
                self.query_cache[layer_idx] = torch.cat(
                    [self.query_cache[layer_idx], query_states], dim=-2
                )

        return self.key_cache[layer_idx], self.value_cache[layer_idx]
```

### 1. Attention Weight Computation

```python
import math
import torch


def repeat_kv(hidden_states: torch.Tensor, n_rep: int) -> torch.Tensor:
    """Expand key/value states to match the number of query heads."""
    batch, num_key_value_heads, slen, head_dim = hidden_states.shape
    if n_rep == 1:
        return hidden_states
    hidden_states = hidden_states[:, :, None, :, :].expand(
        batch, num_key_value_heads, n_rep, slen, head_dim
    )
    return hidden_states.reshape(batch, num_key_value_heads * n_rep, slen, head_dim)


def get_causal_mask(attn_weights):
    """Generate causal attention mask."""
    query_len, seq_len = attn_weights.size(-2), attn_weights.size(-1)
    causal_mask = torch.ones_like(attn_weights.transpose(-1, -2).squeeze(0))
    causal_mask = torch.triu(causal_mask, diagonal=-(seq_len - query_len))
    causal_mask = causal_mask.transpose(-1, -2)
    causal_mask = (1 - causal_mask) * torch.finfo(causal_mask.dtype).min
    return causal_mask


def get_attn_weights(key_states, query_states):
    """Compute attention weights between query and key states."""
    bsz, num_heads, q_len, head_dim = query_states.size()
    num_key_value_heads = key_states.size(1)
    num_key_value_groups = num_heads // num_key_value_heads
    kv_seq_len = key_states.size(-2)

    # Expand key states to match query heads
    key_states = repeat_kv(key_states, num_key_value_groups)

    # Scaled dot-product attention
    scale = 1.0 / math.sqrt(head_dim)
    scaled_queries = query_states * scale
    attn_weights = torch.matmul(scaled_queries, key_states.transpose(2, 3))

    # Apply causal mask
    causal_mask = get_causal_mask(attn_weights).to(attn_weights.device)
    attn_weights += causal_mask.unsqueeze(0)

    # Softmax normalization
    attn_lses = torch.logsumexp(attn_weights, dim=-1, keepdim=True)
    attn_weights = torch.exp(attn_weights - attn_lses)

    return attn_weights
```

### 2. QRRanker Score Computation

```python
def compute_qr_scores(
    query_cache,
    key_cache,
    qr_head_list,
    chunk_ranges,
    query_upper_bound,
):
    """
    Compute QRRanker attention scores for document chunks.

    Args:
        query_cache: List of query states from each layer
        key_cache: List of key states from each layer
        qr_head_list: String of QR heads, e.g., "20-15,21-11,17-27,..."
        chunk_ranges: List of [start, end] token positions for each chunk
        query_upper_bound: Upper bound token position for query

    Returns:
        scores: Tensor of shape [num_chunks] with relevance scores
    """
    all_head_scores = []

    for key_state, query_state in zip(key_cache, query_cache):
        # Compute attention weights
        attn_weights = get_attn_weights(
            key_state[:, :, :query_upper_bound, :],
            query_state
        )
        # Average over query positions
        attn_weights = attn_weights.mean(dim=-2)

        # Aggregate scores for each chunk
        chunk_scores = []
        for start, end in chunk_ranges:
            chunk_scores.append(attn_weights[:, :, start:end].sum(dim=-1))
        chunk_scores = torch.stack(chunk_scores, dim=2)
        all_head_scores.append(chunk_scores)

    # Stack all layers: [batch, num_layers, num_heads, num_chunks]
    all_head_scores = torch.stack(all_head_scores, dim=1).float()

    # Select specific QR heads
    if qr_head_list is not None:
        head_set = [tuple(map(int, h.split('-'))) for h in qr_head_list.split(',')]
        indices = torch.tensor(head_set).to(all_head_scores.device)
        layers, heads = indices[:, 0], indices[:, 1]
        all_head_scores = all_head_scores[:, layers, heads, :]

    # Sum over selected heads
    scores = all_head_scores.sum(dim=1).squeeze(0)

    return scores
```
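
The aggregation performed by `compute_qr_scores` can be illustrated on a toy example without any tensors. The sketch below assumes two heads whose attention rows have already been averaged over the query tokens; the numbers, the chunk boundaries, and the helper name `toy_chunk_scores` are made up for illustration.

```python
def toy_chunk_scores(attn_rows, chunk_ranges):
    """For each chunk, sum attention mass over its tokens and over all given heads."""
    return [
        sum(sum(row[start:end]) for row in attn_rows)
        for start, end in chunk_ranges
    ]


# One attention row per selected head, over 8 context tokens.
attn_rows = [
    [0.05, 0.05, 0.30, 0.25, 0.05, 0.10, 0.10, 0.10],  # head A
    [0.10, 0.10, 0.20, 0.20, 0.10, 0.10, 0.10, 0.10],  # head B
]
# Two chunks: tokens 2-3 and tokens 5-7 (end index is exclusive).
chunk_ranges = [[2, 4], [5, 8]]

scores = toy_chunk_scores(attn_rows, chunk_ranges)
# Rank chunk positions by score, highest first.
ranking = sorted(range(len(scores)), key=lambda i: scores[i], reverse=True)
print(scores, ranking)
```

Because each score is a sum of attention mass rather than a class label, the result is a continuous value per chunk, which is what makes listwise ranking over the shortlist straightforward.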

### 3. Complete Inference Pipeline

```python
import torch

from custom_cache_new import DynamicCacheWithQuery


def rerank_documents(model, tokenizer, question, paragraphs, qr_head_list, device):
    """
    Rerank documents based on QRRanker scores.

    Args:
        model: QRRanker model
        tokenizer: Tokenizer
        question: Query string
        paragraphs: List of paragraph dicts with 'idx' and 'paragraph_text'
        qr_head_list: QR head list string (e.g., "20-15,21-11,17-27,...")
        device: torch device

    Returns:
        ranked_ids: List of paragraph IDs sorted by relevance
        scores: Corresponding relevance scores
    """
    # Build input sequence
    prompt_prefix = '<|im_start|>user\n'
    retrieval_instruction = "Here are some retrieved chunks:\n\n"

    chunk_part = prompt_prefix + retrieval_instruction
    chunk_ranges = []

    for i, p in enumerate(paragraphs):
        text = p.get('title', '') + ': ' + p['paragraph_text']
        chunk_part += f"[{i+1}]"
        start = len(chunk_part)
        chunk_part += ' ' + text.strip()
        end = len(chunk_part)
        chunk_ranges.append([start, end])  # character range of this chunk
        chunk_part += '\n\n'

    query_part = f"Use the retrieved chunks to answer the user's query.\n\nQuery: {question}"
    full_seq = chunk_part + query_part

    # Tokenize with offset mapping
    inputs = tokenizer(
        full_seq,
        max_length=262144,
        truncation=True,
        return_tensors='pt',
        return_offsets_mapping=True,
        add_special_tokens=False
    )

    input_ids = inputs['input_ids'].to(device)
    attention_mask = inputs['attention_mask'].to(device)
    # Convert to plain Python ints so the pairs can be used with range()
    offset_mapping = inputs['offset_mapping'][0].tolist()

    # Build character-to-token mapping
    char_to_token = {}
    for i, (s, e) in enumerate(offset_mapping):
        for j in range(s, e):
            char_to_token[j] = i

    # Map chunk character ranges to token ranges
    token_chunk_ranges = []
    for start, end in chunk_ranges:
        token_start = char_to_token.get(start, 0)
        token_end = char_to_token.get(end - 1, 0) + 1
        token_chunk_ranges.append([token_start, token_end])

    # Get query token positions.
    # The question is the last thing in full_seq, so search from the right:
    # index() would match an earlier chunk that happens to contain the question text.
    query_start_char = full_seq.rindex(question)
    query_end_char = query_start_char + len(question) - 1
    query_positions = list(range(
        char_to_token[query_start_char],
        char_to_token[query_end_char] + 1
    ))
    query_upper_bound = query_positions[-1] + 1

    # Forward pass with custom cache
    with torch.no_grad():
        # Initialize cache with query token positions
        past_kv = DynamicCacheWithQuery(query_indices=query_positions)

        # Run model forward pass
        output = model(input_ids, attention_mask, past_key_values=past_kv)

        # Extract query and key states from cache
        query_cache = output.past_key_values.query_cache
        key_cache = output.past_key_values.key_cache

        # Compute relevance scores
        scores = compute_qr_scores(
            query_cache, key_cache,
            qr_head_list, token_chunk_ranges, query_upper_bound
        )

    # Sort by scores (descending)
    sorted_indices = torch.argsort(scores, descending=True).cpu().tolist()
    ranked_ids = [paragraphs[i]['idx'] for i in sorted_indices]
    ranked_scores = [float(scores[i]) for i in sorted_indices]

    return ranked_ids, ranked_scores
```
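
The character-to-token mapping is the step in this pipeline most prone to off-by-one errors, so here is a small standalone sketch of it. The `offsets` list stands in for a tokenizer's `offset_mapping` and is hand-written for illustration; a real tokenizer may split the text differently, and `char_span_to_token_span` is a hypothetical helper name.

```python
def char_span_to_token_span(offsets, start_char, end_char):
    """Map a character range [start_char, end_char) to a token range [start, end).

    `offsets` holds one (char_start, char_end) pair per token,
    in the style of a Hugging Face `offset_mapping`.
    """
    char_to_token = {}
    for token_idx, (s, e) in enumerate(offsets):
        for pos in range(s, e):
            char_to_token[pos] = token_idx
    token_start = char_to_token[start_char]
    # end_char is exclusive, so look up the last covered character and add 1.
    token_end = char_to_token[end_char - 1] + 1
    return token_start, token_end


text = "[1] Paris is the capital"
# Hypothetical tokenization: "[1]", " Paris", " is", " the", " capital"
offsets = [(0, 3), (3, 9), (9, 12), (12, 16), (16, 24)]

start = text.index(" Paris")      # character 3
end = start + len(" Paris is")    # character 12 (exclusive)
print(char_span_to_token_span(offsets, start, end))  # (1, 3)
```

The same exclusive-end convention is what lets the token ranges be used directly as slices in `compute_qr_scores`.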

## Model Configuration

The model configuration includes the following QRRanker-specific parameters:

| Parameter | Description |
|-----------|-------------|
| `qr_start_layer` | Starting layer index for QR heads |
| `qr_end_layer` | Ending layer index for QR heads |
| `qr_head_list` | List of (layer, head) tuples for top QR heads |

### Default Top-16 QR Heads

```
20-15, 21-11, 17-27, 23-10, 22-4, 21-10, 21-8, 21-18,
18-15, 18-19, 17-25, 17-17, 24-13, 17-4, 19-12, 21-31
```
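
`compute_qr_scores` takes the heads as a comma-separated string of `layer-head` pairs. As a minimal sketch, the list above can be parsed into `(layer, head)` tuples as follows; `parse_qr_heads` and `DEFAULT_TOP16` are illustrative names, not identifiers from the QRRanker code.

```python
def parse_qr_heads(qr_head_list: str):
    """Parse comma-separated "layer-head" pairs into (layer, head) tuples."""
    heads = []
    for item in qr_head_list.split(","):
        item = item.strip()
        if not item:
            continue  # tolerate trailing commas and line breaks
        layer, head = item.split("-")
        heads.append((int(layer), int(head)))
    return heads


DEFAULT_TOP16 = (
    "20-15, 21-11, 17-27, 23-10, 22-4, 21-10, 21-8, 21-18, "
    "18-15, 18-19, 17-25, 17-17, 24-13, 17-4, 19-12, 21-31"
)

heads = parse_qr_heads(DEFAULT_TOP16)
layers = [layer for layer, _ in heads]
print(len(heads), heads[0], min(layers), max(layers))  # 16 (20, 15) 17 24
```

All sixteen default heads sit in layers 17 to 24.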

## Command Line Usage

```bash
# Basic inference
python qr_ranker_inference.py \
    --base_model MindscapeRAG/QRRanker \
    --data_path /path/to/data.json \
    --mode top16

# With summary
python qr_ranker_inference.py \
    --base_model MindscapeRAG/QRRanker \
    --data_path /path/to/data.json \
    --mode top16 \
    --use_summary
```

### Arguments

| Argument | Type | Default | Description |
|----------|------|---------|-------------|
| `--base_model` | str | required | Path to QRRanker model |
| `--data_path` | str | required | Path to input data file |
| `--output_dir` | str | `./outputs` | Output directory |
| `--mode` | str | `top16` | Mode: `full` (all heads) or `top16` (selected heads) |
| `--qr_head_list` | str | None | Custom QR head list |
| `--use_summary` | flag | False | Use summary field in data |
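
The source of `qr_ranker_inference.py` is not reproduced in this card. For readers who want to wrap the functions above in their own script, the following is only a sketch of an argument parser consistent with the table; `build_parser` and the help strings are assumptions, and the real script may differ.

```python
import argparse


def build_parser() -> argparse.ArgumentParser:
    """Hypothetical parser mirroring the documented arguments and defaults."""
    parser = argparse.ArgumentParser(description="QRRanker inference (sketch)")
    parser.add_argument("--base_model", type=str, required=True,
                        help="Path to QRRanker model")
    parser.add_argument("--data_path", type=str, required=True,
                        help="Path to input data file")
    parser.add_argument("--output_dir", type=str, default="./outputs",
                        help="Output directory")
    parser.add_argument("--mode", type=str, default="top16",
                        choices=["full", "top16"],
                        help="full: all heads; top16: selected heads")
    parser.add_argument("--qr_head_list", type=str, default=None,
                        help="Custom QR head list")
    parser.add_argument("--use_summary", action="store_true",
                        help="Use summary field in data")
    return parser


args = build_parser().parse_args(
    ["--base_model", "MindscapeRAG/QRRanker", "--data_path", "data.json", "--use_summary"]
)
print(args.mode, args.output_dir, args.use_summary)  # top16 ./outputs True
```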

## Citation

If you use QRRanker, please cite:

```bibtex
@misc{li2026queryfocusedmemoryawarererankerlong,
      title={Query-focused and Memory-aware Reranker for Long Context Processing},
      author={Yuqing Li and Jiangnan Li and Mo Yu and Guoxuan Ding and Zheng Lin and Weiping Wang and Jie Zhou},
      year={2026},
      eprint={2602.12192},
      archivePrefix={arXiv},
      primaryClass={cs.CL},
      url={https://arxiv.org/abs/2602.12192},
}
```

## License

This project is licensed under the Apache 2.0 License.