---
license: apache-2.0
datasets:
- hotpotqa/hotpot_qa
- dgslibisey/MuSiQue
- Aman279/Locomo
- Phospheneser/DetectiveQA
language:
- en
- zh
metrics:
- accuracy
- exact_match
- f1
- recall
base_model:
- Qwen/Qwen3-4B-Instruct-2507
pipeline_tag: text-ranking
tags:
- Rerank
- Memory
---

# QRRanker: Query-focused and Memory-aware Reranker for Long Context Processing

🌐 Project Page | 📄 Paper | 🤗 Models

QRRanker is a lightweight reranking framework that leverages **Query-focused Retrieval (QR) heads** to produce continuous relevance scores, enabling effective listwise reranking with small-scale models.

## Model Description

Building on existing analyses of retrieval heads in large language models, QRRanker trains models to estimate passage–query relevance from the attention scores of selected **Query-focused Retrieval (QR) heads**. These heads are identified through QR score computation on seed data and are particularly effective at capturing query–document relevance signals.

Our approach provides a **listwise solution** that leverages the holistic information within the entire candidate shortlist during ranking. It naturally produces **continuous relevance scores**, enabling training on arbitrary retrieval datasets without requiring Likert-scale supervision.

### Key Features

- **Listwise Reranking**: Leverages holistic information within the entire candidate shortlist during ranking
- **Continuous Relevance Scores**: Enables training on arbitrary retrieval datasets without requiring Likert-scale supervision
- **Selective Head Usage**: Focuses on top-performing QR attention heads
- **Memory Enhancement**: Optional contextual summaries for improved accuracy on long narratives and dialogues

## Quick Start

### Basic Usage

```python
import torch
from transformers import AutoModel, AutoConfig, AutoTokenizer

# Load model
config = AutoConfig.from_pretrained("MindscapeRAG/QRRanker", trust_remote_code=True)
model = AutoModel.from_pretrained(
    "MindscapeRAG/QRRanker",
    config=config,
    torch_dtype=torch.float16,
    trust_remote_code=True,
)
model.eval()

# Load tokenizer
tokenizer = AutoTokenizer.from_pretrained("MindscapeRAG/QRRanker", trust_remote_code=True)
```

## Input Data Format

Input data should be in JSON format.
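As a quick sanity check, a sample in this format can be constructed and validated with Python's standard `json` module. This is a minimal sketch, not part of the released tooling; the field values are illustrative and the required-field set follows the schema described in this card:

```python
import json

# A minimal sample in the expected schema (values are illustrative).
sample = {
    "id": "sample_001",
    "question": "What is the capital of France?",
    "paragraphs": [
        {"idx": 0, "title": "France",
         "paragraph_text": "Paris is the capital and largest city of France...",
         "is_supporting": True},
        {"idx": 1, "title": "Germany",
         "paragraph_text": "Berlin is the capital of Germany...",
         "is_supporting": False},
    ],
}

# Round-trip through JSON and check that the required fields are present.
loaded = json.loads(json.dumps(sample))
required = {"id", "question", "paragraphs"}
missing = required - loaded.keys()
print(missing)      # set() when the sample is well-formed

# Supporting-paragraph labels are optional and only used for evaluation.
supporting = [p["idx"] for p in loaded["paragraphs"] if p.get("is_supporting")]
print(supporting)   # [0]
```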
Each sample contains the following fields:

```json
{
  "id": "sample_001",
  "question": "What is the capital of France?",
  "answer": "Paris",
  "paragraphs": [
    {
      "idx": 0,
      "title": "France",
      "paragraph_text": "Paris is the capital and largest city of France...",
      "is_supporting": true
    },
    {
      "idx": 1,
      "title": "Germany",
      "paragraph_text": "Berlin is the capital of Germany...",
      "is_supporting": false
    }
  ],
  "summary": "Optional summary text..."
}
```

### Field Description

| Field | Type | Required | Description |
|-------|------|----------|-------------|
| `id` | string | Yes | Unique sample identifier |
| `question` | string | Yes | User query/question |
| `answer` | string | No | Ground-truth answer (for evaluation) |
| `paragraphs` | list | Yes | List of candidate paragraphs |
| `paragraphs[].idx` | int | Yes | Paragraph index |
| `paragraphs[].title` | string | No | Paragraph title |
| `paragraphs[].paragraph_text` | string | Yes | Paragraph content |
| `paragraphs[].is_supporting` | bool | No | Whether it is a supporting paragraph (for evaluation) |
| `summary` | string | No | Optional summary information |

## Core Algorithm

### 0. DynamicCacheWithQuery (Custom Cache Class)

This custom cache class is essential for QRRanker. It extends the standard `DynamicCache` to also store query states at specified positions.

```python
from typing import Any, Dict, Optional, Tuple

import torch
from transformers.cache_utils import DynamicCache


class DynamicCacheWithQuery(DynamicCache):
    """
    Custom cache class for QRRanker that stores both key/value states and
    query states. The query states are extracted at specified token positions
    for attention computation.
    """

    def __init__(self, query_indices=None) -> None:
        super().__init__()
        # Token indices where query states should be saved
        self._query_indices = query_indices if query_indices is not None else []
        self.query_cache = []

    def update(
        self,
        key_states: torch.Tensor,
        value_states: torch.Tensor,
        layer_idx: int,
        cache_kwargs: Optional[Dict[str, Any]] = None,
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        """
        Updates the cache with new key_states, value_states, and optionally
        query_states.

        Parameters:
            key_states: New key states to cache [batch, num_kv_heads, seq_len, head_dim]
            value_states: New value states to cache [batch, num_kv_heads, seq_len, head_dim]
            layer_idx: Index of the layer
            cache_kwargs: Optional dict containing 'query_states' to cache

        Returns:
            Tuple of (updated_key_states, updated_value_states)
        """
        # Update seen tokens count
        if layer_idx == 0:
            self._seen_tokens += key_states.shape[-2]

        # Update key/value cache
        if key_states is not None:
            if len(self.key_cache) <= layer_idx:
                # Pad any skipped layers with empty tensors, then append
                for _ in range(len(self.key_cache), layer_idx):
                    self.key_cache.append(torch.tensor([]))
                    self.value_cache.append(torch.tensor([]))
                self.key_cache.append(key_states)
                self.value_cache.append(value_states)
            elif not self.key_cache[layer_idx].numel():
                self.key_cache[layer_idx] = key_states
                self.value_cache[layer_idx] = value_states
            else:
                self.key_cache[layer_idx] = torch.cat(
                    [self.key_cache[layer_idx], key_states], dim=-2
                )
                self.value_cache[layer_idx] = torch.cat(
                    [self.value_cache[layer_idx], value_states], dim=-2
                )

        # Update query cache if query_states provided
        query_states = cache_kwargs.get("query_states", None) if cache_kwargs else None
        if query_states is not None:
            if len(self.query_cache) <= layer_idx:
                self.query_cache.append(query_states)
            else:
                self.query_cache[layer_idx] = torch.cat(
                    [self.query_cache[layer_idx], query_states], dim=-2
                )

        return self.key_cache[layer_idx], self.value_cache[layer_idx]
```

### 1. Attention Weight Computation

```python
import math

import torch


def repeat_kv(hidden_states: torch.Tensor, n_rep: int) -> torch.Tensor:
    """Expand key/value states to match the number of query heads."""
    batch, num_key_value_heads, slen, head_dim = hidden_states.shape
    if n_rep == 1:
        return hidden_states
    hidden_states = hidden_states[:, :, None, :, :].expand(
        batch, num_key_value_heads, n_rep, slen, head_dim
    )
    return hidden_states.reshape(batch, num_key_value_heads * n_rep, slen, head_dim)


def get_causal_mask(attn_weights):
    """Generate a causal attention mask."""
    query_len, seq_len = attn_weights.size(-2), attn_weights.size(-1)
    causal_mask = torch.ones_like(attn_weights.transpose(-1, -2).squeeze(0))
    causal_mask = torch.triu(causal_mask, diagonal=-(seq_len - query_len))
    causal_mask = causal_mask.transpose(-1, -2)
    causal_mask = (1 - causal_mask) * torch.finfo(causal_mask.dtype).min
    return causal_mask


def get_attn_weights(key_states, query_states):
    """Compute attention weights between query and key states."""
    bsz, num_heads, q_len, head_dim = query_states.size()
    num_key_value_heads = key_states.size(1)
    num_key_value_groups = num_heads // num_key_value_heads
    kv_seq_len = key_states.size(-2)

    # Expand key states to match query heads
    key_states = repeat_kv(key_states, num_key_value_groups)

    # Scaled dot-product attention
    scale = 1.0 / math.sqrt(head_dim)
    scaled_queries = query_states * scale
    attn_weights = torch.matmul(scaled_queries, key_states.transpose(2, 3))

    # Apply causal mask
    causal_mask = get_causal_mask(attn_weights).to(attn_weights.device)
    attn_weights += causal_mask.unsqueeze(0)

    # Softmax normalization (via logsumexp for numerical stability)
    attn_lses = torch.logsumexp(attn_weights, dim=-1, keepdim=True)
    attn_weights = torch.exp(attn_weights - attn_lses)
    return attn_weights
```

### 2. QRRanker Score Computation

```python
def compute_qr_scores(
    query_cache,
    key_cache,
    qr_head_list,
    chunk_ranges,
    query_upper_bound,
):
    """
    Compute QRRanker attention scores for document chunks.

    Args:
        query_cache: List of query states from each layer
        key_cache: List of key states from each layer
        qr_head_list: String of QR heads, e.g., "20-15,21-11,17-27,..."
        chunk_ranges: List of [start, end] token positions for each chunk
        query_upper_bound: Upper-bound token position for the query

    Returns:
        scores: Tensor of shape [num_chunks] with relevance scores
    """
    all_head_scores = []
    for key_state, query_state in zip(key_cache, query_cache):
        # Compute attention weights
        attn_weights = get_attn_weights(
            key_state[:, :, :query_upper_bound, :], query_state
        )
        # Average over query positions
        attn_weights = attn_weights.mean(dim=-2)
        # Aggregate scores for each chunk
        chunk_scores = []
        for start, end in chunk_ranges:
            chunk_scores.append(attn_weights[:, :, start:end].sum(dim=-1))
        chunk_scores = torch.stack(chunk_scores, dim=2)
        all_head_scores.append(chunk_scores)

    # Stack all layers: [batch, num_layers, num_heads, num_chunks]
    all_head_scores = torch.stack(all_head_scores, dim=1).float()

    # Select specific QR heads
    if qr_head_list is not None:
        head_set = [tuple(map(int, h.split('-'))) for h in qr_head_list.split(',')]
        indices = torch.tensor(head_set).to(all_head_scores.device)
        layers, heads = indices[:, 0], indices[:, 1]
        all_head_scores = all_head_scores[:, layers, heads, :]

    # Sum over selected heads
    scores = all_head_scores.sum(dim=1).squeeze(0)
    return scores
```

### 3. Complete Inference Pipeline

```python
import torch

from custom_cache_new import DynamicCacheWithQuery


def rerank_documents(model, tokenizer, question, paragraphs, qr_head_list, device):
    """
    Rerank documents based on QRRanker scores.

    Args:
        model: QRRanker model
        tokenizer: Tokenizer
        question: Query string
        paragraphs: List of paragraph dicts with 'idx' and 'paragraph_text'
        qr_head_list: QR head list string (e.g., "20-15,21-11,17-27,...")
        device: torch device

    Returns:
        ranked_ids: List of paragraph IDs sorted by relevance
        scores: Corresponding relevance scores
    """
    # Build input sequence
    prompt_prefix = '<|im_start|>user\n'
    retrieval_instruction = "Here are some retrieved chunks:\n\n"
    chunk_part = prompt_prefix + retrieval_instruction
    chunk_ranges = []
    for i, p in enumerate(paragraphs):
        text = p.get('title', '') + ': ' + p['paragraph_text']
        chunk_part += f"[{i+1}]"
        start = len(chunk_part)
        chunk_part += ' ' + text.strip()
        end = len(chunk_part)
        chunk_ranges.append([start, end])
        chunk_part += '\n\n'
    query_part = f"Use the retrieved chunks to answer the user's query.\n\nQuery: {question}"
    full_seq = chunk_part + query_part

    # Tokenize with offset mapping
    inputs = tokenizer(
        full_seq,
        max_length=262144,
        truncation=True,
        return_tensors='pt',
        return_offsets_mapping=True,
        add_special_tokens=False
    )
    input_ids = inputs['input_ids'].to(device)
    attention_mask = inputs['attention_mask'].to(device)
    offset_mapping = inputs['offset_mapping'][0]

    # Build character-to-token mapping
    char_to_token = {}
    for i, (s, e) in enumerate(offset_mapping):
        for j in range(s, e):
            char_to_token[j] = i

    # Map chunk character ranges to token ranges
    token_chunk_ranges = []
    for start, end in chunk_ranges:
        token_start = char_to_token.get(start, 0)
        token_end = char_to_token.get(end - 1, 0) + 1
        token_chunk_ranges.append([token_start, token_end])

    # Get query token positions; use the last occurrence (the query at the
    # end of the prompt), in case the question text also appears in a chunk
    query_start_char = full_seq.rindex(question)
    query_end_char = query_start_char + len(question) - 1
    query_positions = list(range(
        char_to_token[query_start_char],
        char_to_token[query_end_char] + 1
    ))
    query_upper_bound = query_positions[-1] + 1

    # Forward pass with custom cache
    with torch.no_grad():
        # Initialize cache with query token positions
        past_kv = DynamicCacheWithQuery(query_indices=query_positions)
        # Run model forward pass
        output = model(input_ids, attention_mask, past_key_values=past_kv)

    # Extract query and key states from cache
    query_cache = output.past_key_values.query_cache
    key_cache = output.past_key_values.key_cache

    # Compute relevance scores
    scores = compute_qr_scores(
        query_cache, key_cache, qr_head_list,
        token_chunk_ranges, query_upper_bound
    )

    # Sort by scores (descending)
    sorted_indices = torch.argsort(scores, descending=True).cpu().tolist()
    ranked_ids = [paragraphs[i]['idx'] for i in sorted_indices]
    ranked_scores = [float(scores[i]) for i in sorted_indices]
    return ranked_ids, ranked_scores
```

## Model Configuration

The model configuration includes the following QRRanker-specific parameters:

| Parameter | Description |
|-----------|-------------|
| `qr_start_layer` | Starting layer index for QR heads |
| `qr_end_layer` | Ending layer index for QR heads |
| `qr_head_list` | List of (layer, head) tuples for top QR heads |

### Default Top-16 QR Heads

```
20-15, 21-11, 17-27, 23-10, 22-4, 21-10, 21-8, 21-18,
18-15, 18-19, 17-25, 17-17, 24-13, 17-4, 19-12, 21-31
```

## Command Line Usage

```bash
# Basic inference
python qr_ranker_inference.py \
    --base_model MindscapeRAG/QRRanker \
    --data_path /path/to/data.json \
    --mode top16

# With summary
python qr_ranker_inference.py \
    --base_model MindscapeRAG/QRRanker \
    --data_path /path/to/data.json \
    --mode top16 \
    --use_summary
```

### Arguments

| Argument | Type | Default | Description |
|----------|------|---------|-------------|
| `--base_model` | str | required | Path to QRRanker model |
| `--data_path` | str | required | Path to input data file |
| `--output_dir` | str | `./outputs` | Output directory |
| `--mode` | str | `top16` | Mode: `full` (all heads) or `top16` (selected heads) |
| `--qr_head_list` | str | None | Custom QR head list |
| `--use_summary` | flag | False | Use summary field in data |

## Citation

If you use QRRanker, please kindly cite:
```bibtex
@misc{li2026queryfocusedmemoryawarererankerlong,
      title={Query-focused and Memory-aware Reranker for Long Context Processing},
      author={Yuqing Li and Jiangnan Li and Mo Yu and Guoxuan Ding and Zheng Lin and Weiping Wang and Jie Zhou},
      year={2026},
      eprint={2602.12192},
      archivePrefix={arXiv},
      primaryClass={cs.CL},
      url={https://arxiv.org/abs/2602.12192},
}
```

## License

This project is licensed under the Apache 2.0 License.