---
license: apache-2.0
datasets:
- hotpotqa/hotpot_qa
- dgslibisey/MuSiQue
- Aman279/Locomo
- Phospheneser/DetectiveQA
language:
- en
- zh
metrics:
- accuracy
- exact_match
- f1
- recall
base_model:
- Qwen/Qwen3-4B-Instruct-2507
pipeline_tag: text-ranking
tags:
- Rerank
- Memory
---

# QRRanker: Query-focused and Memory-aware Reranker for Long Context Processing

<p align="center">
<a href="https://qdcassie-li.github.io/QRRanker/"><b>🌐 Project Page</b></a> |
<a href="https://arxiv.org/abs/2602.12192"><b>📄 Paper</b></a> |
<a href="https://huggingface.co/MindscapeRAG/QRRanker"><b>🤗 Models</b></a>
</p>

QRRanker is a lightweight reranking framework that leverages **Query-focused Retrieval (QR) heads** to produce continuous relevance scores, enabling effective listwise reranking with small-scale models.

## Model Description

Building on prior analyses of retrieval heads in large language models, QRRanker trains models to estimate passage–query relevance from the attention scores of selected **Query-focused Retrieval (QR) heads**. These heads are identified by computing QR scores on seed data and are particularly effective at capturing query–document relevance signals.

Our approach provides a **listwise solution** that leverages the holistic information in the entire candidate shortlist during ranking. It naturally produces **continuous relevance scores**, enabling training on arbitrary retrieval datasets without requiring Likert-scale supervision.

### Key Features

- **Listwise Reranking**: Leverages holistic information within the entire candidate shortlist during ranking
- **Continuous Relevance Scores**: Enables training on arbitrary retrieval datasets without requiring Likert-scale supervision
- **Selective Head Usage**: Focuses on top-performing QR attention heads
- **Memory Enhancement**: Optional contextual summaries for improved accuracy on long narratives and dialogues

## Quick Start

### Basic Usage

```python
import torch
from transformers import AutoModel, AutoConfig, AutoTokenizer

# Load the model configuration and weights
config = AutoConfig.from_pretrained("MindscapeRAG/QRRanker", trust_remote_code=True)
model = AutoModel.from_pretrained(
    "MindscapeRAG/QRRanker",
    config=config,
    torch_dtype=torch.float16,
    trust_remote_code=True,
)
model.eval()

# Load the tokenizer
tokenizer = AutoTokenizer.from_pretrained("MindscapeRAG/QRRanker", trust_remote_code=True)
```

## Input Data Format

Input data should be in JSON format. Each sample contains the following fields:

```json
{
  "id": "sample_001",
  "question": "What is the capital of France?",
  "answer": "Paris",
  "paragraphs": [
    {
      "idx": 0,
      "title": "France",
      "paragraph_text": "Paris is the capital and largest city of France...",
      "is_supporting": true
    },
    {
      "idx": 1,
      "title": "Germany",
      "paragraph_text": "Berlin is the capital of Germany...",
      "is_supporting": false
    }
  ],
  "summary": "Optional summary text..."
}
```

### Field Description

| Field | Type | Required | Description |
|-------|------|----------|-------------|
| `id` | string | Yes | Unique sample identifier |
| `question` | string | Yes | User query/question |
| `answer` | string | No | Ground-truth answer (for evaluation) |
| `paragraphs` | list | Yes | List of candidate paragraphs |
| `paragraphs[].idx` | int | Yes | Paragraph index |
| `paragraphs[].title` | string | No | Paragraph title |
| `paragraphs[].paragraph_text` | string | No | Paragraph content |
| `paragraphs[].is_supporting` | bool | No | Whether the paragraph supports the answer (for evaluation) |
| `summary` | string | No | Optional summary information |
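
As a quick sanity check before running inference, a sample can be validated against the required fields above. The `validate_sample` helper below is illustrative only, not part of the released code:

```python
# One sample in the input format described above.
sample = {
    "id": "sample_001",
    "question": "What is the capital of France?",
    "paragraphs": [
        {"idx": 0, "title": "France",
         "paragraph_text": "Paris is the capital and largest city of France..."},
        {"idx": 1, "paragraph_text": "Berlin is the capital of Germany..."},
    ],
}

def validate_sample(s):
    """Check that the required fields from the table above are present."""
    if not all(k in s for k in ("id", "question", "paragraphs")):
        return False
    return all("idx" in p and "paragraph_text" in p for p in s["paragraphs"])

assert validate_sample(sample)
assert not validate_sample({"id": "missing_question_and_paragraphs"})
```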

## Core Algorithm

### 0. DynamicCacheWithQuery (Custom Cache Class)

This custom cache class is essential for QRRanker. It extends the standard `DynamicCache` to also store query states at specified positions.

```python
from typing import Any, Dict, Optional, Tuple

import torch
from transformers.cache_utils import DynamicCache


class DynamicCacheWithQuery(DynamicCache):
    """
    Custom cache class for QRRanker that stores both key/value states and query states.
    The query states are extracted at specified token positions for attention computation.
    """

    def __init__(self, query_indices=None) -> None:
        super().__init__()
        # Token indices where query states should be saved
        # (default is None, not a mutable [], to avoid shared state between instances)
        self._query_indices = query_indices if query_indices is not None else []
        self.query_cache = []

    def update(
        self,
        key_states: torch.Tensor,
        value_states: torch.Tensor,
        layer_idx: int,
        cache_kwargs: Optional[Dict[str, Any]] = None,
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        """
        Update the cache with new key_states, value_states, and optionally query_states.

        Parameters:
            key_states: New key states to cache [batch, num_kv_heads, seq_len, head_dim]
            value_states: New value states to cache [batch, num_kv_heads, seq_len, head_dim]
            layer_idx: Index of the layer
            cache_kwargs: Optional dict containing 'query_states' to cache

        Returns:
            Tuple of (updated_key_states, updated_value_states)
        """
        # Update the seen-token count once per forward pass
        if layer_idx == 0:
            self._seen_tokens += key_states.shape[-2]

        # Update the key/value cache
        if key_states is not None:
            if len(self.key_cache) <= layer_idx:
                # Pad any skipped layers with empty tensors
                for _ in range(len(self.key_cache), layer_idx):
                    self.key_cache.append(torch.tensor([]))
                    self.value_cache.append(torch.tensor([]))
                self.key_cache.append(key_states)
                self.value_cache.append(value_states)
            elif not self.key_cache[layer_idx].numel():
                self.key_cache[layer_idx] = key_states
                self.value_cache[layer_idx] = value_states
            else:
                self.key_cache[layer_idx] = torch.cat(
                    [self.key_cache[layer_idx], key_states], dim=-2
                )
                self.value_cache[layer_idx] = torch.cat(
                    [self.value_cache[layer_idx], value_states], dim=-2
                )

        # Update the query cache if query_states were provided
        query_states = cache_kwargs.get("query_states") if cache_kwargs else None
        if query_states is not None:
            if len(self.query_cache) <= layer_idx:
                self.query_cache.append(query_states)
            else:
                self.query_cache[layer_idx] = torch.cat(
                    [self.query_cache[layer_idx], query_states], dim=-2
                )

        return self.key_cache[layer_idx], self.value_cache[layer_idx]
```
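
The `update` contract can be illustrated without torch or transformers. The mock below is a schematic, tensor-free sketch of the per-layer accumulation logic (list extension stands in for tensor concatenation); the class and variable names are hypothetical:

```python
# Tensor-free sketch of the cache contract: per-layer key/value lists grow
# on every update, and query states are stored only when passed via cache_kwargs.
class MockCacheWithQuery:
    def __init__(self):
        self.key_cache, self.value_cache, self.query_cache = [], [], []

    def update(self, key_states, value_states, layer_idx, cache_kwargs=None):
        if len(self.key_cache) <= layer_idx:
            self.key_cache.append(list(key_states))
            self.value_cache.append(list(value_states))
        else:
            self.key_cache[layer_idx].extend(key_states)
            self.value_cache[layer_idx].extend(value_states)
        q = (cache_kwargs or {}).get("query_states")
        if q is not None:
            if len(self.query_cache) <= layer_idx:
                self.query_cache.append(list(q))
            else:
                self.query_cache[layer_idx].extend(q)
        return self.key_cache[layer_idx], self.value_cache[layer_idx]

cache = MockCacheWithQuery()
cache.update(["k0"], ["v0"], layer_idx=0)  # no query states at this step
k, v = cache.update(["k1"], ["v1"], layer_idx=0,
                    cache_kwargs={"query_states": ["q1"]})
print(k, cache.query_cache)  # ['k0', 'k1'] [['q1']]
```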

### 1. Attention Weight Computation

```python
import math

import torch


def repeat_kv(hidden_states: torch.Tensor, n_rep: int) -> torch.Tensor:
    """Expand key/value states to match the number of query heads."""
    batch, num_key_value_heads, slen, head_dim = hidden_states.shape
    if n_rep == 1:
        return hidden_states
    hidden_states = hidden_states[:, :, None, :, :].expand(
        batch, num_key_value_heads, n_rep, slen, head_dim
    )
    return hidden_states.reshape(batch, num_key_value_heads * n_rep, slen, head_dim)


def get_causal_mask(attn_weights):
    """Generate a causal attention mask for the trailing query positions."""
    query_len, seq_len = attn_weights.size(-2), attn_weights.size(-1)
    causal_mask = torch.ones_like(attn_weights.transpose(-1, -2).squeeze(0))
    causal_mask = torch.triu(causal_mask, diagonal=-(seq_len - query_len))
    causal_mask = causal_mask.transpose(-1, -2)
    causal_mask = (1 - causal_mask) * torch.finfo(causal_mask.dtype).min
    return causal_mask


def get_attn_weights(key_states, query_states):
    """Compute attention weights between query and key states."""
    bsz, num_heads, q_len, head_dim = query_states.size()
    num_key_value_heads = key_states.size(1)
    num_key_value_groups = num_heads // num_key_value_heads

    # Expand key states to match the number of query heads
    key_states = repeat_kv(key_states, num_key_value_groups)

    # Scaled dot-product attention
    scale = 1.0 / math.sqrt(head_dim)
    scaled_queries = query_states * scale
    attn_weights = torch.matmul(scaled_queries, key_states.transpose(2, 3))

    # Apply the causal mask
    causal_mask = get_causal_mask(attn_weights).to(attn_weights.device)
    attn_weights += causal_mask.unsqueeze(0)

    # Numerically stable softmax via logsumexp
    attn_lses = torch.logsumexp(attn_weights, dim=-1, keepdim=True)
    attn_weights = torch.exp(attn_weights - attn_lses)

    return attn_weights
```
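
The final normalization in `get_attn_weights` computes `exp(x - logsumexp(x))`, which is softmax in its numerically stable form (`torch.logsumexp` shifts by the maximum internally). A scalar sketch of the identity, in plain Python for illustration:

```python
import math

def softmax_naive(xs):
    """Textbook softmax: overflows for large logits."""
    exps = [math.exp(x) for x in xs]
    total = sum(exps)
    return [e / total for e in exps]

def softmax_lse(xs):
    """Softmax as exp(x - logsumexp(x)), with the max-shift used by torch.logsumexp."""
    m = max(xs)
    lse = m + math.log(sum(math.exp(x - m) for x in xs))
    return [math.exp(x - lse) for x in xs]

xs = [0.5, 2.0, -1.0]
assert all(abs(a - b) < 1e-12 for a, b in zip(softmax_naive(xs), softmax_lse(xs)))

# The stable form also handles logits where math.exp(x) alone would overflow:
print(softmax_lse([1000.0, 1001.0]))  # ≈ [0.269, 0.731]
```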

### 2. QRRanker Score Computation

```python
import torch

# Uses get_attn_weights from the previous section.

def compute_qr_scores(
    query_cache,
    key_cache,
    qr_head_list,
    chunk_ranges,
    query_upper_bound,
):
    """
    Compute QRRanker attention scores for document chunks.

    Args:
        query_cache: List of query states from each layer
        key_cache: List of key states from each layer
        qr_head_list: String of QR heads, e.g., "20-15,21-11,17-27,..."
        chunk_ranges: List of [start, end] token positions for each chunk
        query_upper_bound: Upper-bound token position for the query

    Returns:
        scores: Tensor of shape [num_chunks] with relevance scores
    """
    all_head_scores = []

    for key_state, query_state in zip(key_cache, query_cache):
        # Attention weights of the query positions over the context
        attn_weights = get_attn_weights(
            key_state[:, :, :query_upper_bound, :],
            query_state,
        )
        # Average over query positions
        attn_weights = attn_weights.mean(dim=-2)

        # Aggregate attention mass for each chunk
        chunk_scores = []
        for start, end in chunk_ranges:
            chunk_scores.append(attn_weights[:, :, start:end].sum(dim=-1))
        chunk_scores = torch.stack(chunk_scores, dim=2)
        all_head_scores.append(chunk_scores)

    # Stack all layers: [batch, num_layers, num_heads, num_chunks]
    all_head_scores = torch.stack(all_head_scores, dim=1).float()

    # Select the specified QR heads
    if qr_head_list is not None:
        head_set = [tuple(map(int, h.split('-'))) for h in qr_head_list.split(',')]
        indices = torch.tensor(head_set).to(all_head_scores.device)
        layers, heads = indices[:, 0], indices[:, 1]
        all_head_scores = all_head_scores[:, layers, heads, :]

    # Sum over the selected heads
    scores = all_head_scores.sum(dim=1).squeeze(0)

    return scores
```
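
The `qr_head_list` string encodes one `layer-head` pair per comma-separated entry. The parsing step used above can be checked in isolation; `parse_qr_heads` is a hypothetical helper extracted for illustration:

```python
def parse_qr_heads(qr_head_list):
    """Parse a 'layer-head' string such as '20-15,21-11' into (layer, head) tuples."""
    return [tuple(map(int, h.split('-'))) for h in qr_head_list.split(',')]

heads = parse_qr_heads("20-15,21-11,17-27")
print(heads)  # [(20, 15), (21, 11), (17, 27)]
```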

### 3. Complete Inference Pipeline

```python
import torch

# DynamicCacheWithQuery as defined in Section 0
from custom_cache_new import DynamicCacheWithQuery


def rerank_documents(model, tokenizer, question, paragraphs, qr_head_list, device):
    """
    Rerank documents based on QRRanker scores.

    Args:
        model: QRRanker model
        tokenizer: Tokenizer
        question: Query string
        paragraphs: List of paragraph dicts with 'idx' and 'paragraph_text'
        qr_head_list: QR head list string (e.g., "20-15,21-11,17-27,...")
        device: torch device

    Returns:
        ranked_ids: List of paragraph IDs sorted by relevance
        ranked_scores: Corresponding relevance scores
    """
    # Build the input sequence and record each chunk's character range
    prompt_prefix = '<|im_start|>user\n'
    retrieval_instruction = "Here are some retrieved chunks:\n\n"

    chunk_part = prompt_prefix + retrieval_instruction
    chunk_ranges = []

    for i, p in enumerate(paragraphs):
        text = p.get('title', '') + ': ' + p['paragraph_text']
        chunk_part += f"[{i+1}]"
        start = len(chunk_part)
        chunk_part += ' ' + text.strip()
        end = len(chunk_part)
        chunk_ranges.append([start, end])
        chunk_part += '\n\n'

    query_part = f"Use the retrieved chunks to answer the user's query.\n\nQuery: {question}"
    full_seq = chunk_part + query_part

    # Tokenize with offset mapping
    inputs = tokenizer(
        full_seq,
        max_length=262144,
        truncation=True,
        return_tensors='pt',
        return_offsets_mapping=True,
        add_special_tokens=False,
    )

    input_ids = inputs['input_ids'].to(device)
    attention_mask = inputs['attention_mask'].to(device)
    offset_mapping = inputs['offset_mapping'][0]

    # Build a character-to-token mapping
    char_to_token = {}
    for i, (s, e) in enumerate(offset_mapping):
        for j in range(s, e):
            char_to_token[j] = i

    # Map chunk character ranges to token ranges
    token_chunk_ranges = []
    for start, end in chunk_ranges:
        token_start = char_to_token.get(start, 0)
        token_end = char_to_token.get(end - 1, 0) + 1
        token_chunk_ranges.append([token_start, token_end])

    # Get the query token positions
    query_start_char = full_seq.index(question)
    query_end_char = query_start_char + len(question) - 1
    query_positions = list(range(
        char_to_token[query_start_char],
        char_to_token[query_end_char] + 1,
    ))
    query_upper_bound = query_positions[-1] + 1

    # Forward pass with the custom cache
    with torch.no_grad():
        # Initialize the cache with the query token positions
        past_kv = DynamicCacheWithQuery(query_indices=query_positions)

        # Run the model forward pass
        output = model(input_ids, attention_mask, past_key_values=past_kv)

    # Extract query and key states from the cache
    query_cache = output.past_key_values.query_cache
    key_cache = output.past_key_values.key_cache

    # Compute relevance scores
    scores = compute_qr_scores(
        query_cache, key_cache,
        qr_head_list, token_chunk_ranges, query_upper_bound,
    )

    # Sort by score (descending)
    sorted_indices = torch.argsort(scores, descending=True).cpu().tolist()
    ranked_ids = [paragraphs[i]['idx'] for i in sorted_indices]
    ranked_scores = [float(scores[i]) for i in sorted_indices]

    return ranked_ids, ranked_scores
```
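
The character-to-token mapping step above can be exercised without a real tokenizer. The offset mapping below is a hypothetical stand-in (token `i` covers characters `[s, e)`) and only illustrates how character spans become token ranges:

```python
# Hypothetical offset mapping for a 4-token text: token i covers chars [s, e).
offset_mapping = [(0, 4), (4, 9), (9, 15), (15, 20)]

# Same mapping construction as in rerank_documents above.
char_to_token = {}
for i, (s, e) in enumerate(offset_mapping):
    for j in range(s, e):
        char_to_token[j] = i

def to_token_range(start, end):
    """Convert a character span [start, end) to a token range [token_start, token_end)."""
    return [char_to_token.get(start, 0), char_to_token.get(end - 1, 0) + 1]

print(to_token_range(4, 15))  # chars 4..14 fall in tokens 1 and 2 -> [1, 3]
```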

## Model Configuration

The model configuration includes the following QRRanker-specific parameters:

| Parameter | Description |
|-----------|-------------|
| `qr_start_layer` | Starting layer index for QR heads |
| `qr_end_layer` | Ending layer index for QR heads |
| `qr_head_list` | List of (layer, head) tuples for the top QR heads |

### Default Top-16 QR Heads

```
20-15, 21-11, 17-27, 23-10, 22-4, 21-10, 21-8, 21-18,
18-15, 18-19, 17-25, 17-17, 24-13, 17-4, 19-12, 21-31
```

## Command Line Usage

```bash
# Basic inference
python qr_ranker_inference.py \
    --base_model MindscapeRAG/QRRanker \
    --data_path /path/to/data.json \
    --mode top16

# With summary
python qr_ranker_inference.py \
    --base_model MindscapeRAG/QRRanker \
    --data_path /path/to/data.json \
    --mode top16 \
    --use_summary
```

### Arguments

| Argument | Type | Default | Description |
|----------|------|---------|-------------|
| `--base_model` | str | required | Path to the QRRanker model |
| `--data_path` | str | required | Path to the input data file |
| `--output_dir` | str | `./outputs` | Output directory |
| `--mode` | str | `top16` | Mode: `full` (all heads) or `top16` (selected heads) |
| `--qr_head_list` | str | None | Custom QR head list |
| `--use_summary` | flag | False | Use the `summary` field in the data |

## Citation

If you use QRRanker, please cite:

```bibtex
@misc{li2026queryfocusedmemoryawarererankerlong,
  title={Query-focused and Memory-aware Reranker for Long Context Processing},
  author={Yuqing Li and Jiangnan Li and Mo Yu and Guoxuan Ding and Zheng Lin and Weiping Wang and Jie Zhou},
  year={2026},
  eprint={2602.12192},
  archivePrefix={arXiv},
  primaryClass={cs.CL},
  url={https://arxiv.org/abs/2602.12192},
}
```

## License

This project is licensed under the Apache 2.0 License.