---
license: apache-2.0
datasets:
- hotpotqa/hotpot_qa
- dgslibisey/MuSiQue
- Aman279/Locomo
- Phospheneser/DetectiveQA
language:
- en
- zh
metrics:
- accuracy
- exact_match
- f1
- recall
base_model:
- Qwen/Qwen3-4B-Instruct-2507
pipeline_tag: text-ranking
tags:
- Rerank
- Memory
---
# QRRanker: Query-focused and Memory-aware Reranker for Long Context Processing
<p align="center">
<a href="https://qdcassie-li.github.io/QRRanker/"><b>🌐 Project Page</b></a> |
<a href="https://arxiv.org/abs/2602.12192"><b>📄 Paper</b></a> |
<a href="https://huggingface.co/MindscapeRAG/QRRanker"><b>🤗 Models</b></a>
</p>
QRRanker is a lightweight reranking framework that leverages **Query-focused Retrieval (QR) heads** to produce continuous relevance scores, enabling effective listwise reranking with small-scale models.
## Model Description
Building on prior analyses of retrieval heads in large language models, QRRanker trains models to estimate passage–query relevance from the attention scores of selected **Query-focused Retrieval (QR) heads**. These heads are identified by computing QR scores on seed data and are particularly effective at capturing query–document relevance signals.
Our approach provides a **listwise solution** that leverages the holistic information within the entire candidate shortlist during ranking. It naturally produces **continuous relevance scores**, enabling training on arbitrary retrieval datasets without requiring Likert-scale supervision.
### Key Features
- **Listwise Reranking**: Leverages holistic information within the entire candidate shortlist during ranking
- **Continuous Relevance Scores**: Enables training on arbitrary retrieval datasets without requiring Likert-scale supervision
- **Selective Head Usage**: Focuses on top-performing QR attention heads
- **Memory Enhancement**: Optional contextual summaries for improved accuracy on long narratives and dialogues
## Quick Start
### Basic Usage
```python
import torch
from transformers import AutoModel, AutoConfig, AutoTokenizer

# Load model
config = AutoConfig.from_pretrained("MindscapeRAG/QRRanker", trust_remote_code=True)
model = AutoModel.from_pretrained(
    "MindscapeRAG/QRRanker",
    config=config,
    torch_dtype=torch.float16,
    trust_remote_code=True,
)
model.eval()

# Load tokenizer
tokenizer = AutoTokenizer.from_pretrained("MindscapeRAG/QRRanker", trust_remote_code=True)
```
## Input Data Format
Input data should be in JSON format. Each sample contains the following fields:
```json
{
  "id": "sample_001",
  "question": "What is the capital of France?",
  "answer": "Paris",
  "paragraphs": [
    {
      "idx": 0,
      "title": "France",
      "paragraph_text": "Paris is the capital and largest city of France...",
      "is_supporting": true
    },
    {
      "idx": 1,
      "title": "Germany",
      "paragraph_text": "Berlin is the capital of Germany...",
      "is_supporting": false
    }
  ],
  "summary": "Optional summary text..."
}
```
### Field Description
| Field | Type | Required | Description |
|-------|------|----------|-------------|
| `id` | string | Yes | Unique sample identifier |
| `question` | string | Yes | User query/question |
| `answer` | string | No | Ground truth answer (for evaluation) |
| `paragraphs` | list | Yes | List of candidate paragraphs |
| `paragraphs[].idx` | int | Yes | Paragraph index |
| `paragraphs[].title` | string | No | Paragraph title |
| `paragraphs[].paragraph_text` | string | Yes | Paragraph content |
| `paragraphs[].is_supporting` | bool | No | Whether it's a supporting paragraph (for evaluation) |
| `summary` | string | No | Optional summary information |
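As a quick sanity check on this schema, a minimal validator can flag samples that are missing required fields. This is a hypothetical helper for illustration, not part of the released code:

```python
import json

REQUIRED_SAMPLE_FIELDS = {"id", "question", "paragraphs"}
REQUIRED_PARAGRAPH_FIELDS = {"idx", "paragraph_text"}

def validate_sample(sample: dict) -> list:
    """Return a list of problems found in one input sample (empty if valid)."""
    problems = [f"missing required field: {f}"
                for f in sorted(REQUIRED_SAMPLE_FIELDS - sample.keys())]
    for i, p in enumerate(sample.get("paragraphs", [])):
        problems += [f"paragraphs[{i}] missing: {f}"
                     for f in sorted(REQUIRED_PARAGRAPH_FIELDS - p.keys())]
    return problems

sample = json.loads('''{
  "id": "sample_001",
  "question": "What is the capital of France?",
  "paragraphs": [{"idx": 0, "paragraph_text": "Paris is the capital of France."}]
}''')
print(validate_sample(sample))  # []
```

Optional fields (`answer`, `title`, `is_supporting`, `summary`) are intentionally not checked, since they are only used for evaluation or the memory-enhanced mode.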
## Core Algorithm
### 0. DynamicCacheWithQuery (Custom Cache Class)
This custom cache class is essential for QRRanker. It extends the standard `DynamicCache` to also store query states at specified positions.
```python
from typing import Any, Dict, Optional, Tuple

import torch
from transformers.cache_utils import DynamicCache


class DynamicCacheWithQuery(DynamicCache):
    """
    Custom cache class for QRRanker that stores both key/value states and query states.
    The query states are extracted at specified token positions for attention computation.
    """

    def __init__(self, query_indices=None) -> None:
        super().__init__()
        # Token indices where query states should be saved
        # (avoid a mutable default argument by defaulting to None)
        self._query_indices = query_indices if query_indices is not None else []
        self.query_cache = []

    def update(
        self,
        key_states: torch.Tensor,
        value_states: torch.Tensor,
        layer_idx: int,
        cache_kwargs: Optional[Dict[str, Any]] = None,
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        """
        Updates the cache with new key_states, value_states, and optionally query_states.

        Parameters:
            key_states: New key states to cache [batch, num_kv_heads, seq_len, head_dim]
            value_states: New value states to cache [batch, num_kv_heads, seq_len, head_dim]
            layer_idx: Index of the layer
            cache_kwargs: Optional dict containing 'query_states' to cache

        Returns:
            Tuple of (updated_key_states, updated_value_states)
        """
        # Update seen tokens count
        if layer_idx == 0:
            self._seen_tokens += key_states.shape[-2]

        # Update key/value cache
        if key_states is not None:
            if len(self.key_cache) <= layer_idx:
                # Pad any skipped layers with empty placeholders, then append
                for _ in range(len(self.key_cache), layer_idx):
                    self.key_cache.append(torch.tensor([]))
                    self.value_cache.append(torch.tensor([]))
                self.key_cache.append(key_states)
                self.value_cache.append(value_states)
            elif not self.key_cache[layer_idx].numel():
                # Fill a previously padded (empty) slot
                self.key_cache[layer_idx] = key_states
                self.value_cache[layer_idx] = value_states
            else:
                # Concatenate new states along the sequence axis
                self.key_cache[layer_idx] = torch.cat(
                    [self.key_cache[layer_idx], key_states], dim=-2
                )
                self.value_cache[layer_idx] = torch.cat(
                    [self.value_cache[layer_idx], value_states], dim=-2
                )

        # Update query cache if query_states provided
        query_states = cache_kwargs.get("query_states") if cache_kwargs else None
        if query_states is not None:
            if len(self.query_cache) <= layer_idx:
                self.query_cache.append(query_states)
            else:
                self.query_cache[layer_idx] = torch.cat(
                    [self.query_cache[layer_idx], query_states], dim=-2
                )

        return self.key_cache[layer_idx], self.value_cache[layer_idx]
```
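The layer-indexing logic in `update` (padding with empty placeholders when a later layer arrives before an earlier one, then filling or concatenating) can be illustrated with plain Python lists. `update_layer` below is a toy stand-in for illustration only, with list `+` playing the role of `torch.cat`:

```python
def update_layer(key_cache, layer_idx, key_state):
    """Toy version of the placeholder logic in DynamicCacheWithQuery.update."""
    if len(key_cache) <= layer_idx:
        # Pad skipped layers with empty placeholders, then append
        key_cache.extend([None] * (layer_idx - len(key_cache)))
        key_cache.append(key_state)
    elif key_cache[layer_idx] is None:
        # Fill a previously padded slot
        key_cache[layer_idx] = key_state
    else:
        # Concatenate along the sequence axis (list "+" stands in for torch.cat)
        key_cache[layer_idx] = key_cache[layer_idx] + key_state

cache = []
update_layer(cache, 2, ["k2"])   # layer 2 arrives first
print(cache)                     # [None, None, ['k2']]
update_layer(cache, 0, ["k0"])   # layer 0 fills its placeholder
print(cache)                     # [['k0'], None, ['k2']]
update_layer(cache, 2, ["k2b"])  # new tokens for layer 2 are concatenated
print(cache)                     # [['k0'], None, ['k2', 'k2b']]
```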
### 1. Attention Weight Computation
```python
import math

import torch


def repeat_kv(hidden_states: torch.Tensor, n_rep: int) -> torch.Tensor:
    """Expand key/value states to match the number of query heads."""
    batch, num_key_value_heads, slen, head_dim = hidden_states.shape
    if n_rep == 1:
        return hidden_states
    hidden_states = hidden_states[:, :, None, :, :].expand(
        batch, num_key_value_heads, n_rep, slen, head_dim
    )
    return hidden_states.reshape(batch, num_key_value_heads * n_rep, slen, head_dim)


def get_causal_mask(attn_weights):
    """Generate a causal attention mask matching the attention weight shape."""
    query_len, seq_len = attn_weights.size(-2), attn_weights.size(-1)
    causal_mask = torch.ones_like(attn_weights.transpose(-1, -2).squeeze(0))
    causal_mask = torch.triu(causal_mask, diagonal=-(seq_len - query_len))
    causal_mask = causal_mask.transpose(-1, -2)
    causal_mask = (1 - causal_mask) * torch.finfo(causal_mask.dtype).min
    return causal_mask


def get_attn_weights(key_states, query_states):
    """Compute attention weights between query and key states."""
    bsz, num_heads, q_len, head_dim = query_states.size()
    num_key_value_heads = key_states.size(1)
    num_key_value_groups = num_heads // num_key_value_heads
    kv_seq_len = key_states.size(-2)

    # Expand key states to match query heads
    key_states = repeat_kv(key_states, num_key_value_groups)

    # Scaled dot-product attention
    scale = 1.0 / math.sqrt(head_dim)
    scaled_queries = query_states * scale
    attn_weights = torch.matmul(scaled_queries, key_states.transpose(2, 3))

    # Apply causal mask
    causal_mask = get_causal_mask(attn_weights).to(attn_weights.device)
    attn_weights += causal_mask.unsqueeze(0)

    # Softmax normalization (numerically stable via log-sum-exp)
    attn_lses = torch.logsumexp(attn_weights, dim=-1, keepdim=True)
    attn_weights = torch.exp(attn_weights - attn_lses)

    return attn_weights
```
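To see what `repeat_kv` does for grouped-query attention, here is a toy example expanding 2 KV heads to 4 query heads. The function is restated so the snippet runs standalone; the shapes are made up for illustration:

```python
import torch

def repeat_kv(hidden_states: torch.Tensor, n_rep: int) -> torch.Tensor:
    """Expand each KV head n_rep times along the head axis (as above)."""
    batch, num_kv_heads, slen, head_dim = hidden_states.shape
    if n_rep == 1:
        return hidden_states
    expanded = hidden_states[:, :, None, :, :].expand(
        batch, num_kv_heads, n_rep, slen, head_dim
    )
    return expanded.reshape(batch, num_kv_heads * n_rep, slen, head_dim)

# 2 KV heads, 3 tokens, head_dim 1 -> expanded to 4 query heads
kv = torch.arange(6, dtype=torch.float32).reshape(1, 2, 3, 1)
out = repeat_kv(kv, 2)
print(out.shape)                          # torch.Size([1, 4, 3, 1])
print(torch.equal(out[0, 0], out[0, 1]))  # True: query heads 0-1 share KV head 0
print(torch.equal(out[0, 2], out[0, 3]))  # True: query heads 2-3 share KV head 1
```

Adjacent query heads within a group see identical keys, which is exactly what lets the attention weights later be computed head-by-head against the smaller KV cache.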
### 2. QRRanker Score Computation
```python
def compute_qr_scores(
    query_cache,
    key_cache,
    qr_head_list,
    chunk_ranges,
    query_upper_bound,
):
    """
    Compute QRRanker attention scores for document chunks.

    Args:
        query_cache: List of query states from each layer
        key_cache: List of key states from each layer
        qr_head_list: String of QR heads, e.g., "20-15,21-11,17-27,..."
        chunk_ranges: List of [start, end] token positions for each chunk
        query_upper_bound: Upper bound token position for query

    Returns:
        scores: Tensor of shape [num_chunks] with relevance scores
    """
    all_head_scores = []
    for key_state, query_state in zip(key_cache, query_cache):
        # Compute attention weights
        attn_weights = get_attn_weights(
            key_state[:, :, :query_upper_bound, :],
            query_state
        )
        # Average over query positions
        attn_weights = attn_weights.mean(dim=-2)

        # Aggregate scores for each chunk
        chunk_scores = []
        for start, end in chunk_ranges:
            chunk_scores.append(attn_weights[:, :, start:end].sum(dim=-1))
        chunk_scores = torch.stack(chunk_scores, dim=2)
        all_head_scores.append(chunk_scores)

    # Stack all layers: [batch, num_layers, num_heads, num_chunks]
    all_head_scores = torch.stack(all_head_scores, dim=1).float()

    # Select specific QR heads
    if qr_head_list is not None:
        head_set = [tuple(map(int, h.split('-'))) for h in qr_head_list.split(',')]
        indices = torch.tensor(head_set).to(all_head_scores.device)
        layers, heads = indices[:, 0], indices[:, 1]
        all_head_scores = all_head_scores[:, layers, heads, :]

    # Sum over selected heads
    scores = all_head_scores.sum(dim=1).squeeze(0)
    return scores
```
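The head-selection step is plain advanced indexing. The synthetic example below (made-up shapes and scores, independent of the model) shows how a `layer-head` string picks out individual heads from the stacked score tensor and how the summed scores rank chunks:

```python
import torch

# Synthetic per-head scores: [batch=1, num_layers=4, num_heads=8, num_chunks=3]
all_head_scores = torch.zeros(1, 4, 8, 3)
all_head_scores[0, 2, 5] = torch.tensor([0.1, 0.7, 0.2])  # head (layer 2, head 5)
all_head_scores[0, 3, 1] = torch.tensor([0.2, 0.6, 0.2])  # head (layer 3, head 1)

# Parse the "layer-head" string and select those heads via advanced indexing
head_set = [tuple(map(int, h.split('-'))) for h in "2-5,3-1".split(',')]
indices = torch.tensor(head_set)
layers, heads = indices[:, 0], indices[:, 1]
selected = all_head_scores[:, layers, heads, :]  # [1, num_selected=2, num_chunks=3]

# Sum over selected heads to get one relevance score per chunk
scores = selected.sum(dim=1).squeeze(0)
print(int(torch.argmax(scores)))  # 1 -> chunk 1 is ranked first
```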
### 3. Complete Inference Pipeline
```python
import torch

from custom_cache_new import DynamicCacheWithQuery  # the cache class defined in section 0


def rerank_documents(model, tokenizer, question, paragraphs, qr_head_list, device):
    """
    Rerank documents based on QRRanker scores.

    Args:
        model: QRRanker model
        tokenizer: Tokenizer
        question: Query string
        paragraphs: List of paragraph dicts with 'idx' and 'paragraph_text'
        qr_head_list: QR head list string (e.g., "20-15,21-11,17-27,...")
        device: torch device

    Returns:
        ranked_ids: List of paragraph IDs sorted by relevance
        scores: Corresponding relevance scores
    """
    # Build input sequence
    prompt_prefix = '<|im_start|>user\n'
    retrieval_instruction = "Here are some retrieved chunks:\n\n"
    chunk_part = prompt_prefix + retrieval_instruction
    chunk_ranges = []
    for i, p in enumerate(paragraphs):
        text = p.get('title', '') + ': ' + p['paragraph_text']
        chunk_part += f"[{i+1}]"
        start = len(chunk_part)
        chunk_part += ' ' + text.strip()
        end = len(chunk_part)
        chunk_ranges.append([start, end])
        chunk_part += '\n\n'
    query_part = f"Use the retrieved chunks to answer the user's query.\n\nQuery: {question}"
    full_seq = chunk_part + query_part

    # Tokenize with offset mapping
    inputs = tokenizer(
        full_seq,
        max_length=262144,
        truncation=True,
        return_tensors='pt',
        return_offsets_mapping=True,
        add_special_tokens=False
    )
    input_ids = inputs['input_ids'].to(device)
    attention_mask = inputs['attention_mask'].to(device)
    offset_mapping = inputs['offset_mapping'][0].tolist()

    # Build character-to-token mapping
    char_to_token = {}
    for i, (s, e) in enumerate(offset_mapping):
        for j in range(s, e):
            char_to_token[j] = i

    # Map chunk character ranges to token ranges
    token_chunk_ranges = []
    for start, end in chunk_ranges:
        token_start = char_to_token.get(start, 0)
        token_end = char_to_token.get(end - 1, 0) + 1
        token_chunk_ranges.append([token_start, token_end])

    # Get query token positions
    query_start_char = full_seq.index(question)
    query_end_char = query_start_char + len(question) - 1
    query_positions = list(range(
        char_to_token[query_start_char],
        char_to_token[query_end_char] + 1
    ))
    query_upper_bound = query_positions[-1] + 1

    # Forward pass with custom cache
    with torch.no_grad():
        # Initialize cache with query token positions
        past_kv = DynamicCacheWithQuery(query_indices=query_positions)
        # Run model forward pass
        output = model(input_ids, attention_mask, past_key_values=past_kv)

    # Extract query and key states from cache
    query_cache = output.past_key_values.query_cache
    key_cache = output.past_key_values.key_cache

    # Compute relevance scores
    scores = compute_qr_scores(
        query_cache, key_cache,
        qr_head_list, token_chunk_ranges, query_upper_bound
    )

    # Sort by scores (descending)
    sorted_indices = torch.argsort(scores, descending=True).cpu().tolist()
    ranked_ids = [paragraphs[i]['idx'] for i in sorted_indices]
    ranked_scores = [float(scores[i]) for i in sorted_indices]

    return ranked_ids, ranked_scores
```
## Model Configuration
The model configuration includes the following QRRanker-specific parameters:
| Parameter | Description |
|-----------|-------------|
| `qr_start_layer` | Starting layer index for QR heads |
| `qr_end_layer` | Ending layer index for QR heads |
| `qr_head_list` | List of (layer, head) tuples for top QR heads |
### Default Top-16 QR Heads
```
20-15, 21-11, 17-27, 23-10, 22-4, 21-10, 21-8, 21-18,
18-15, 18-19, 17-25, 17-17, 24-13, 17-4, 19-12, 21-31
```
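This string uses the same `layer-head` format consumed by `compute_qr_scores` above. A small sketch (`parse_qr_heads` is a hypothetical helper, not part of the released code) converts it into `(layer, head)` tuples:

```python
QR_TOP16 = (
    "20-15,21-11,17-27,23-10,22-4,21-10,21-8,21-18,"
    "18-15,18-19,17-25,17-17,24-13,17-4,19-12,21-31"
)

def parse_qr_heads(head_list: str):
    """Parse a 'layer-head' string into (layer, head) integer tuples."""
    return [tuple(map(int, h.split("-"))) for h in head_list.split(",")]

heads = parse_qr_heads(QR_TOP16)
print(len(heads))  # 16
print(heads[0])    # (20, 15)
```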
## Command Line Usage
```bash
# Basic inference
python qr_ranker_inference.py \
    --base_model MindscapeRAG/QRRanker \
    --data_path /path/to/data.json \
    --mode top16

# With summary
python qr_ranker_inference.py \
    --base_model MindscapeRAG/QRRanker \
    --data_path /path/to/data.json \
    --mode top16 \
    --use_summary
```
### Arguments
| Argument | Type | Default | Description |
|----------|------|---------|-------------|
| `--base_model` | str | required | Path to QRRanker model |
| `--data_path` | str | required | Path to input data file |
| `--output_dir` | str | `./outputs` | Output directory |
| `--mode` | str | `top16` | Mode: `full` (all heads) or `top16` (selected heads) |
| `--qr_head_list` | str | None | Custom QR head list |
| `--use_summary` | flag | False | Use summary field in data |
## Citation
If you use QRRanker, please cite:
```bibtex
@misc{li2026queryfocusedmemoryawarererankerlong,
title={Query-focused and Memory-aware Reranker for Long Context Processing},
author={Yuqing Li and Jiangnan Li and Mo Yu and Guoxuan Ding and Zheng Lin and Weiping Wang and Jie Zhou},
year={2026},
eprint={2602.12192},
archivePrefix={arXiv},
primaryClass={cs.CL},
url={https://arxiv.org/abs/2602.12192},
}
```
## License
This project is licensed under the Apache 2.0 License.