Update README.md
Browse files
README.md
CHANGED
|
@@ -1,448 +1,469 @@
|
|
| 1 |
-
|
| 2 |
-
|
| 3 |
-
|
| 4 |
-
|
| 5 |
-
|
| 6 |
-
|
| 7 |
-
|
| 8 |
-
|
| 9 |
-
|
| 10 |
-
|
| 11 |
-
|
| 12 |
-
|
| 13 |
-
|
| 14 |
-
|
| 15 |
-
|
| 16 |
-
|
| 17 |
-
|
| 18 |
-
|
| 19 |
-
|
| 20 |
-
-
|
| 21 |
-
-
|
| 22 |
-
|
| 23 |
-
|
| 24 |
-
|
| 25 |
-
|
| 26 |
-
|
| 27 |
-
|
| 28 |
-
|
| 29 |
-
|
| 30 |
-
|
| 31 |
-
|
| 32 |
-
|
| 33 |
-
|
| 34 |
-
|
| 35 |
-
|
| 36 |
-
|
| 37 |
-
|
| 38 |
-
|
| 39 |
-
|
| 40 |
-
|
| 41 |
-
|
| 42 |
-
|
| 43 |
-
|
| 44 |
-
|
| 45 |
-
|
| 46 |
-
|
| 47 |
-
|
| 48 |
-
|
| 49 |
-
|
| 50 |
-
|
| 51 |
-
|
| 52 |
-
```
|
| 53 |
-
|
| 54 |
-
|
| 55 |
-
|
| 56 |
-
|
| 57 |
-
|
| 58 |
-
|
| 59 |
-
|
| 60 |
-
|
| 61 |
-
|
| 62 |
-
|
| 63 |
-
|
| 64 |
-
|
| 65 |
-
|
| 66 |
-
|
| 67 |
-
|
| 68 |
-
|
| 69 |
-
|
| 70 |
-
|
| 71 |
-
|
| 72 |
-
|
| 73 |
-
|
| 74 |
-
|
| 75 |
-
|
| 76 |
-
|
| 77 |
-
|
| 78 |
-
|
| 79 |
-
|
| 80 |
-
|
| 81 |
-
|
| 82 |
-
|
| 83 |
-
|
| 84 |
-
|
| 85 |
-
|
| 86 |
-
|
| 87 |
-
|
| 88 |
-
|
| 89 |
-
|
| 90 |
-
|
| 91 |
-
|
| 92 |
-
|
| 93 |
-
|
| 94 |
-
|
| 95 |
-
|
| 96 |
-
|
| 97 |
-
|
| 98 |
-
|
| 99 |
-
|
| 100 |
-
|
| 101 |
-
|
| 102 |
-
|
| 103 |
-
|
| 104 |
-
|
| 105 |
-
|
| 106 |
-
|
| 107 |
-
|
| 108 |
-
|
| 109 |
-
|
| 110 |
-
|
| 111 |
-
|
| 112 |
-
|
| 113 |
-
|
| 114 |
-
|
| 115 |
-
|
| 116 |
-
|
| 117 |
-
|
| 118 |
-
|
| 119 |
-
|
| 120 |
-
|
| 121 |
-
|
| 122 |
-
|
| 123 |
-
|
| 124 |
-
|
| 125 |
-
|
| 126 |
-
|
| 127 |
-
|
| 128 |
-
|
| 129 |
-
|
| 130 |
-
|
| 131 |
-
|
| 132 |
-
#
|
| 133 |
-
|
| 134 |
-
|
| 135 |
-
|
| 136 |
-
|
| 137 |
-
|
| 138 |
-
|
| 139 |
-
|
| 140 |
-
|
| 141 |
-
|
| 142 |
-
|
| 143 |
-
|
| 144 |
-
|
| 145 |
-
|
| 146 |
-
|
| 147 |
-
|
| 148 |
-
|
| 149 |
-
|
| 150 |
-
|
| 151 |
-
|
| 152 |
-
|
| 153 |
-
|
| 154 |
-
|
| 155 |
-
|
| 156 |
-
|
| 157 |
-
|
| 158 |
-
|
| 159 |
-
|
| 160 |
-
|
| 161 |
-
|
| 162 |
-
|
| 163 |
-
|
| 164 |
-
|
| 165 |
-
self.
|
| 166 |
-
|
| 167 |
-
|
| 168 |
-
|
| 169 |
-
|
| 170 |
-
|
| 171 |
-
|
| 172 |
-
|
| 173 |
-
|
| 174 |
-
|
| 175 |
-
|
| 176 |
-
|
| 177 |
-
|
| 178 |
-
|
| 179 |
-
|
| 180 |
-
|
| 181 |
-
|
| 182 |
-
|
| 183 |
-
|
| 184 |
-
|
| 185 |
-
|
| 186 |
-
|
| 187 |
-
|
| 188 |
-
|
| 189 |
-
|
| 190 |
-
|
| 191 |
-
|
| 192 |
-
|
| 193 |
-
|
| 194 |
-
|
| 195 |
-
|
| 196 |
-
|
| 197 |
-
|
| 198 |
-
|
| 199 |
-
|
| 200 |
-
|
| 201 |
-
|
| 202 |
-
num_key_value_heads =
|
| 203 |
-
|
| 204 |
-
|
| 205 |
-
|
| 206 |
-
|
| 207 |
-
|
| 208 |
-
|
| 209 |
-
|
| 210 |
-
|
| 211 |
-
|
| 212 |
-
|
| 213 |
-
|
| 214 |
-
|
| 215 |
-
causal_mask =
|
| 216 |
-
|
| 217 |
-
|
| 218 |
-
|
| 219 |
-
|
| 220 |
-
|
| 221 |
-
|
| 222 |
-
|
| 223 |
-
|
| 224 |
-
|
| 225 |
-
|
| 226 |
-
|
| 227 |
-
|
| 228 |
-
|
| 229 |
-
|
| 230 |
-
|
| 231 |
-
|
| 232 |
-
|
| 233 |
-
|
| 234 |
-
)
|
| 235 |
-
|
| 236 |
-
|
| 237 |
-
|
| 238 |
-
|
| 239 |
-
|
| 240 |
-
|
| 241 |
-
|
| 242 |
-
|
| 243 |
-
|
| 244 |
-
|
| 245 |
-
|
| 246 |
-
|
| 247 |
-
|
| 248 |
-
|
| 249 |
-
|
| 250 |
-
|
| 251 |
-
|
| 252 |
-
|
| 253 |
-
|
| 254 |
-
|
| 255 |
-
|
| 256 |
-
|
| 257 |
-
|
| 258 |
-
|
| 259 |
-
|
| 260 |
-
|
| 261 |
-
|
| 262 |
-
|
| 263 |
-
|
| 264 |
-
|
| 265 |
-
|
| 266 |
-
|
| 267 |
-
|
| 268 |
-
|
| 269 |
-
|
| 270 |
-
|
| 271 |
-
|
| 272 |
-
|
| 273 |
-
|
| 274 |
-
|
| 275 |
-
|
| 276 |
-
|
| 277 |
-
|
| 278 |
-
|
| 279 |
-
|
| 280 |
-
|
| 281 |
-
|
| 282 |
-
|
| 283 |
-
|
| 284 |
-
|
| 285 |
-
|
| 286 |
-
|
| 287 |
-
|
| 288 |
-
|
| 289 |
-
|
| 290 |
-
|
| 291 |
-
|
| 292 |
-
|
| 293 |
-
|
| 294 |
-
|
| 295 |
-
|
| 296 |
-
|
| 297 |
-
|
| 298 |
-
|
| 299 |
-
|
| 300 |
-
|
| 301 |
-
|
| 302 |
-
|
| 303 |
-
|
| 304 |
-
|
| 305 |
-
|
| 306 |
-
|
| 307 |
-
|
| 308 |
-
|
| 309 |
-
|
| 310 |
-
|
| 311 |
-
|
| 312 |
-
|
| 313 |
-
|
| 314 |
-
|
| 315 |
-
|
| 316 |
-
|
| 317 |
-
|
| 318 |
-
|
| 319 |
-
|
| 320 |
-
|
| 321 |
-
|
| 322 |
-
|
| 323 |
-
|
| 324 |
-
|
| 325 |
-
|
| 326 |
-
|
| 327 |
-
|
| 328 |
-
|
| 329 |
-
|
| 330 |
-
|
| 331 |
-
|
| 332 |
-
|
| 333 |
-
|
| 334 |
-
|
| 335 |
-
|
| 336 |
-
|
| 337 |
-
|
| 338 |
-
|
| 339 |
-
|
| 340 |
-
|
| 341 |
-
|
| 342 |
-
|
| 343 |
-
|
| 344 |
-
|
| 345 |
-
|
| 346 |
-
|
| 347 |
-
|
| 348 |
-
|
| 349 |
-
|
| 350 |
-
|
| 351 |
-
|
| 352 |
-
|
| 353 |
-
|
| 354 |
-
|
| 355 |
-
)
|
| 356 |
-
|
| 357 |
-
|
| 358 |
-
#
|
| 359 |
-
|
| 360 |
-
|
| 361 |
-
|
| 362 |
-
|
| 363 |
-
|
| 364 |
-
|
| 365 |
-
|
| 366 |
-
|
| 367 |
-
|
| 368 |
-
|
| 369 |
-
|
| 370 |
-
|
| 371 |
-
|
| 372 |
-
|
| 373 |
-
|
| 374 |
-
|
| 375 |
-
|
| 376 |
-
|
| 377 |
-
|
| 378 |
-
|
| 379 |
-
|
| 380 |
-
|
| 381 |
-
|
| 382 |
-
|
| 383 |
-
|
| 384 |
-
|
| 385 |
-
|
| 386 |
-
|
| 387 |
-
|
| 388 |
-
|
| 389 |
-
|
| 390 |
-
|
| 391 |
-
|
| 392 |
-
|
| 393 |
-
|
| 394 |
-
|
| 395 |
-
|
| 396 |
-
|
| 397 |
-
|
| 398 |
-
|
| 399 |
-
|
| 400 |
-
|
| 401 |
-
|
| 402 |
-
|
| 403 |
-
|
| 404 |
-
|
| 405 |
-
|
| 406 |
-
|
| 407 |
-
|
| 408 |
-
|
| 409 |
-
|
| 410 |
-
|
| 411 |
-
|
| 412 |
-
|
| 413 |
-
|
| 414 |
-
|
| 415 |
-
|
| 416 |
-
|
| 417 |
-
|
| 418 |
-
|
| 419 |
-
|
| 420 |
-
|
| 421 |
-
|
| 422 |
-
|
| 423 |
-
|
| 424 |
-
|
| 425 |
-
|
| 426 |
-
|
| 427 |
-
|
| 428 |
-
|
| 429 |
-
|
| 430 |
-
|
| 431 |
-
|
| 432 |
-
|
| 433 |
-
|
| 434 |
-
|
| 435 |
-
|
| 436 |
-
|
| 437 |
-
|
| 438 |
-
|
| 439 |
-
|
| 440 |
-
|
| 441 |
-
|
| 442 |
-
|
| 443 |
-
|
| 444 |
-
|
| 445 |
-
|
| 446 |
-
|
| 447 |
-
|
| 448 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
---
|
| 2 |
+
license: apache-2.0
|
| 3 |
+
datasets:
|
| 4 |
+
- hotpotqa/hotpot_qa
|
| 5 |
+
- dgslibisey/MuSiQue
|
| 6 |
+
- Aman279/Locomo
|
| 7 |
+
- Phospheneser/DetectiveQA
|
| 8 |
+
language:
|
| 9 |
+
- en
|
| 10 |
+
- zh
|
| 11 |
+
metrics:
|
| 12 |
+
- accuracy
|
| 13 |
+
- exact_match
|
| 14 |
+
- f1
|
| 15 |
+
- recall
|
| 16 |
+
base_model:
|
| 17 |
+
- Qwen/Qwen3-4B-Instruct-2507
|
| 18 |
+
pipeline_tag: text-ranking
|
| 19 |
+
tags:
|
| 20 |
+
- Rerank
|
| 21 |
+
- Memory
|
| 22 |
+
---
|
| 23 |
+
# QRRanker: Query-focused and Memory-aware Reranker for Long Context Processing
|
| 24 |
+
|
| 25 |
+
<p align="center">
|
| 26 |
+
<a href="https://qdcassie-li.github.io/QRRanker/"><b>🌐 Project Page</b></a> |
|
| 27 |
+
<a href="https://arxiv.org/abs/2602.12192"><b>📄 Paper</b></a> |
|
| 28 |
+
<a href="https://huggingface.co/MindscapeRAG/QRRanker"><b>🤗 Models</b></a>
|
| 29 |
+
</p>
|
| 30 |
+
|
| 31 |
+
QRRanker is a lightweight reranking framework that leverages **Query-focused Retrieval (QR) heads** to produce continuous relevance scores, enabling effective listwise reranking with small-scale models.
|
| 32 |
+
|
| 33 |
+
|
| 34 |
+
## Model Description
|
| 35 |
+
|
| 36 |
+
Building on existing analyses of retrieval heads in large language models, QRRanker trains models to estimate passage–query relevance using the attention scores of selected **Query-focused Retrieval (QR) heads**. These heads are identified through QR score computation on seed data and are particularly effective at capturing query–document relevance signals.
|
| 37 |
+
|
| 38 |
+
Our approach provides a **listwise solution** that leverages the holistic information within the entire candidate shortlist during ranking. It naturally produces **continuous relevance scores**, enabling training on arbitrary retrieval datasets without requiring Likert-scale supervision.
|
| 39 |
+
|
| 40 |
+
### Key Features
|
| 41 |
+
|
| 42 |
+
- **Listwise Reranking**: Leverages holistic information within the entire candidate shortlist during ranking
|
| 43 |
+
- **Continuous Relevance Scores**: Enables training on arbitrary retrieval datasets without requiring Likert-scale supervision
|
| 44 |
+
- **Selective Head Usage**: Focuses on top-performing QR attention heads
|
| 45 |
+
- **Memory Enhancement**: Optional contextual summaries for improved accuracy on long narratives and dialogues
|
| 46 |
+
|
| 47 |
+
## Quick Start
|
| 48 |
+
|
| 49 |
+
|
| 50 |
+
### Basic Usage
|
| 51 |
+
|
| 52 |
+
```python
|
| 53 |
+
import torch
|
| 54 |
+
from transformers import AutoModel, AutoConfig, AutoTokenizer
|
| 55 |
+
|
| 56 |
+
# Load model
|
| 57 |
+
config = AutoConfig.from_pretrained("MindscapeRAG/QRRanker", trust_remote_code=True)
|
| 58 |
+
model = AutoModel.from_pretrained(
|
| 59 |
+
"MindscapeRAG/QRRanker",
|
| 60 |
+
config=config,
|
| 61 |
+
torch_dtype=torch.float16,
|
| 62 |
+
trust_remote_code=True,
|
| 63 |
+
)
|
| 64 |
+
model.eval()
|
| 65 |
+
|
| 66 |
+
# Load tokenizer
|
| 67 |
+
tokenizer = AutoTokenizer.from_pretrained("MindscapeRAG/QRRanker", trust_remote_code=True)
|
| 68 |
+
```
|
| 69 |
+
|
| 70 |
+
## Input Data Format
|
| 71 |
+
|
| 72 |
+
Input data should be in JSON format. Each sample contains the following fields:
|
| 73 |
+
|
| 74 |
+
```json
|
| 75 |
+
{
|
| 76 |
+
"id": "sample_001",
|
| 77 |
+
"question": "What is the capital of France?",
|
| 78 |
+
"answer": "Paris",
|
| 79 |
+
"paragraphs": [
|
| 80 |
+
{
|
| 81 |
+
"idx": 0,
|
| 82 |
+
"title": "France",
|
| 83 |
+
"paragraph_text": "Paris is the capital and largest city of France...",
|
| 84 |
+
"is_supporting": true
|
| 85 |
+
},
|
| 86 |
+
{
|
| 87 |
+
"idx": 1,
|
| 88 |
+
"title": "Germany",
|
| 89 |
+
"paragraph_text": "Berlin is the capital of Germany...",
|
| 90 |
+
"is_supporting": false
|
| 91 |
+
}
|
| 92 |
+
],
|
| 93 |
+
"summary": "Optional summary text..."
|
| 94 |
+
}
|
| 95 |
+
```
|
| 96 |
+
|
| 97 |
+
### Field Description
|
| 98 |
+
|
| 99 |
+
| Field | Type | Required | Description |
|
| 100 |
+
|-------|------|----------|-------------|
|
| 101 |
+
| `id` | string | Yes | Unique sample identifier |
|
| 102 |
+
| `question` | string | Yes | User query/question |
|
| 103 |
+
| `answer` | string | No | Ground truth answer (for evaluation) |
|
| 104 |
+
| `paragraphs` | list | Yes | List of candidate paragraphs |
|
| 105 |
+
| `paragraphs[].idx` | int | Yes | Paragraph index |
|
| 106 |
+
| `paragraphs[].title` | string | No | Paragraph title |
|
| 107 |
+
| `paragraphs[].paragraph_text` | string | Yes | Paragraph content |
|
| 108 |
+
| `paragraphs[].is_supporting` | bool | No | Whether it's a supporting paragraph (for evaluation) |
|
| 109 |
+
| `summary` | string | No | Optional summary information |
|
| 110 |
+
|
| 111 |
+
## Core Algorithm
|
| 112 |
+
|
| 113 |
+
### 0. DynamicCacheWithQuery (Custom Cache Class)
|
| 114 |
+
|
| 115 |
+
This custom cache class is essential for QRRanker. It extends the standard `DynamicCache` to also store query states at specified positions.
|
| 116 |
+
|
| 117 |
+
|
| 118 |
+
```python
|
| 119 |
+
from typing import Any, Dict, Optional, Tuple
|
| 120 |
+
from transformers.cache_utils import DynamicCache
|
| 121 |
+
import torch
|
| 122 |
+
|
| 123 |
+
|
| 124 |
+
class DynamicCacheWithQuery(DynamicCache):
|
| 125 |
+
"""
|
| 126 |
+
Custom cache class for QRRanker that stores both key/value states and query states.
|
| 127 |
+
The query states are extracted at specified token positions for attention computation.
|
| 128 |
+
"""
|
| 129 |
+
|
| 130 |
+
    def __init__(self, query_indices=()) -> None:  # immutable default avoids the shared-mutable-default pitfall
|
| 131 |
+
super().__init__()
|
| 132 |
+
self._query_indices = query_indices # Token indices where query states should be saved
|
| 133 |
+
self.query_cache = []
|
| 134 |
+
|
| 135 |
+
def update(
|
| 136 |
+
self,
|
| 137 |
+
key_states: torch.Tensor,
|
| 138 |
+
value_states: torch.Tensor,
|
| 139 |
+
layer_idx: int,
|
| 140 |
+
cache_kwargs: Optional[Dict[str, Any]] = None,
|
| 141 |
+
) -> Tuple[torch.Tensor, torch.Tensor]:
|
| 142 |
+
"""
|
| 143 |
+
Updates the cache with new key_states, value_states, and optionally query_states.
|
| 144 |
+
|
| 145 |
+
Parameters:
|
| 146 |
+
key_states: New key states to cache [batch, num_kv_heads, seq_len, head_dim]
|
| 147 |
+
value_states: New value states to cache [batch, num_kv_heads, seq_len, head_dim]
|
| 148 |
+
layer_idx: Index of the layer
|
| 149 |
+
cache_kwargs: Optional dict containing 'query_states' to cache
|
| 150 |
+
|
| 151 |
+
Returns:
|
| 152 |
+
Tuple of (updated_key_states, updated_value_states)
|
| 153 |
+
"""
|
| 154 |
+
# Update seen tokens count
|
| 155 |
+
if layer_idx == 0:
|
| 156 |
+
self._seen_tokens += key_states.shape[-2]
|
| 157 |
+
|
| 158 |
+
# Update key/value cache
|
| 159 |
+
if key_states is not None:
|
| 160 |
+
if len(self.key_cache) <= layer_idx:
|
| 161 |
+
for _ in range(len(self.key_cache), layer_idx):
|
| 162 |
+
self.key_cache.append(torch.tensor([]))
|
| 163 |
+
self.value_cache.append(torch.tensor([]))
|
| 164 |
+
self.key_cache.append(key_states)
|
| 165 |
+
self.value_cache.append(value_states)
|
| 166 |
+
elif not self.key_cache[layer_idx].numel():
|
| 167 |
+
self.key_cache[layer_idx] = key_states
|
| 168 |
+
self.value_cache[layer_idx] = value_states
|
| 169 |
+
else:
|
| 170 |
+
self.key_cache[layer_idx] = torch.cat(
|
| 171 |
+
[self.key_cache[layer_idx], key_states], dim=-2
|
| 172 |
+
)
|
| 173 |
+
self.value_cache[layer_idx] = torch.cat(
|
| 174 |
+
[self.value_cache[layer_idx], value_states], dim=-2
|
| 175 |
+
)
|
| 176 |
+
|
| 177 |
+
# Update query cache if query_states provided
|
| 178 |
+
if cache_kwargs is not None:
|
| 179 |
+
query_states = cache_kwargs.get("query_states", None)
|
| 180 |
+
else:
|
| 181 |
+
query_states = None
|
| 182 |
+
|
| 183 |
+
if query_states is not None:
|
| 184 |
+
if len(self.query_cache) <= layer_idx:
|
| 185 |
+
self.query_cache.append(query_states)
|
| 186 |
+
else:
|
| 187 |
+
self.query_cache[layer_idx] = torch.cat(
|
| 188 |
+
[self.query_cache[layer_idx], query_states], dim=-2
|
| 189 |
+
)
|
| 190 |
+
|
| 191 |
+
return self.key_cache[layer_idx], self.value_cache[layer_idx]
|
| 192 |
+
```
|
| 193 |
+
|
| 194 |
+
### 1. Attention Weight Computation
|
| 195 |
+
|
| 196 |
+
```python
|
| 197 |
+
import math
|
| 198 |
+
import torch
|
| 199 |
+
|
| 200 |
+
def repeat_kv(hidden_states: torch.Tensor, n_rep: int) -> torch.Tensor:
|
| 201 |
+
"""Expand key/value states to match the number of query heads."""
|
| 202 |
+
batch, num_key_value_heads, slen, head_dim = hidden_states.shape
|
| 203 |
+
if n_rep == 1:
|
| 204 |
+
return hidden_states
|
| 205 |
+
hidden_states = hidden_states[:, :, None, :, :].expand(
|
| 206 |
+
batch, num_key_value_heads, n_rep, slen, head_dim
|
| 207 |
+
)
|
| 208 |
+
return hidden_states.reshape(batch, num_key_value_heads * n_rep, slen, head_dim)
|
| 209 |
+
|
| 210 |
+
|
| 211 |
+
def get_causal_mask(attn_weights):
|
| 212 |
+
"""Generate causal attention mask."""
|
| 213 |
+
query_len, seq_len = attn_weights.size(-2), attn_weights.size(-1)
|
| 214 |
+
causal_mask = torch.ones_like(attn_weights.transpose(-1, -2).squeeze(0))
|
| 215 |
+
causal_mask = torch.triu(causal_mask, diagonal=-(seq_len - query_len))
|
| 216 |
+
causal_mask = causal_mask.transpose(-1, -2)
|
| 217 |
+
causal_mask = (1 - causal_mask) * torch.finfo(causal_mask.dtype).min
|
| 218 |
+
return causal_mask
|
| 219 |
+
|
| 220 |
+
|
| 221 |
+
def get_attn_weights(key_states, query_states):
|
| 222 |
+
"""Compute attention weights between query and key states."""
|
| 223 |
+
bsz, num_heads, q_len, head_dim = query_states.size()
|
| 224 |
+
num_key_value_heads = key_states.size(1)
|
| 225 |
+
num_key_value_groups = num_heads // num_key_value_heads
|
| 226 |
+
kv_seq_len = key_states.size(-2)
|
| 227 |
+
|
| 228 |
+
# Expand key states to match query heads
|
| 229 |
+
key_states = repeat_kv(key_states, num_key_value_groups)
|
| 230 |
+
|
| 231 |
+
# Scaled dot-product attention
|
| 232 |
+
scale = 1.0 / math.sqrt(head_dim)
|
| 233 |
+
scaled_queries = query_states * scale
|
| 234 |
+
attn_weights = torch.matmul(scaled_queries, key_states.transpose(2, 3))
|
| 235 |
+
|
| 236 |
+
# Apply causal mask
|
| 237 |
+
causal_mask = get_causal_mask(attn_weights).to(attn_weights.device)
|
| 238 |
+
attn_weights += causal_mask.unsqueeze(0)
|
| 239 |
+
|
| 240 |
+
# Softmax normalization
|
| 241 |
+
attn_lses = torch.logsumexp(attn_weights, dim=-1, keepdim=True)
|
| 242 |
+
attn_weights = torch.exp(attn_weights - attn_lses)
|
| 243 |
+
|
| 244 |
+
return attn_weights
|
| 245 |
+
```
|
| 246 |
+
|
| 247 |
+
### 2. QRRanker Score Computation
|
| 248 |
+
|
| 249 |
+
```python
|
| 250 |
+
def compute_qr_scores(
|
| 251 |
+
query_cache,
|
| 252 |
+
key_cache,
|
| 253 |
+
qr_head_list,
|
| 254 |
+
chunk_ranges,
|
| 255 |
+
query_upper_bound,
|
| 256 |
+
):
|
| 257 |
+
"""
|
| 258 |
+
Compute QRRanker attention scores for document chunks.
|
| 259 |
+
|
| 260 |
+
Args:
|
| 261 |
+
query_cache: List of query states from each layer
|
| 262 |
+
key_cache: List of key states from each layer
|
| 263 |
+
qr_head_list: String of QR heads, e.g., "20-15,21-11,17-27,..."
|
| 264 |
+
chunk_ranges: List of [start, end] token positions for each chunk
|
| 265 |
+
query_upper_bound: Upper bound token position for query
|
| 266 |
+
|
| 267 |
+
Returns:
|
| 268 |
+
scores: Tensor of shape [num_chunks] with relevance scores
|
| 269 |
+
"""
|
| 270 |
+
all_head_scores = []
|
| 271 |
+
|
| 272 |
+
for key_state, query_state in zip(key_cache, query_cache):
|
| 273 |
+
# Compute attention weights
|
| 274 |
+
attn_weights = get_attn_weights(
|
| 275 |
+
key_state[:, :, :query_upper_bound, :],
|
| 276 |
+
query_state
|
| 277 |
+
)
|
| 278 |
+
# Average over query positions
|
| 279 |
+
attn_weights = attn_weights.mean(dim=-2)
|
| 280 |
+
|
| 281 |
+
# Aggregate scores for each chunk
|
| 282 |
+
chunk_scores = []
|
| 283 |
+
for start, end in chunk_ranges:
|
| 284 |
+
chunk_scores.append(attn_weights[:, :, start:end].sum(dim=-1))
|
| 285 |
+
chunk_scores = torch.stack(chunk_scores, dim=2)
|
| 286 |
+
all_head_scores.append(chunk_scores)
|
| 287 |
+
|
| 288 |
+
# Stack all layers: [batch, num_layers, num_heads, num_chunks]
|
| 289 |
+
all_head_scores = torch.stack(all_head_scores, dim=1).float()
|
| 290 |
+
|
| 291 |
+
# Select specific QR heads
|
| 292 |
+
if qr_head_list is not None:
|
| 293 |
+
head_set = [tuple(map(int, h.split('-'))) for h in qr_head_list.split(',')]
|
| 294 |
+
indices = torch.tensor(head_set).to(all_head_scores.device)
|
| 295 |
+
layers, heads = indices[:, 0], indices[:, 1]
|
| 296 |
+
all_head_scores = all_head_scores[:, layers, heads, :]
|
| 297 |
+
|
| 298 |
+
# Sum over selected heads
|
| 299 |
+
scores = all_head_scores.sum(dim=1).squeeze(0)
|
| 300 |
+
|
| 301 |
+
return scores
|
| 302 |
+
```
|
| 303 |
+
|
| 304 |
+
### 3. Complete Inference Pipeline
|
| 305 |
+
|
| 306 |
+
```python
|
| 307 |
+
from custom_cache_new import DynamicCacheWithQuery  # or reuse the DynamicCacheWithQuery class defined in section 0 above
|
| 308 |
+
|
| 309 |
+
def rerank_documents(model, tokenizer, question, paragraphs, qr_head_list, device):
|
| 310 |
+
"""
|
| 311 |
+
Rerank documents based on QRRanker scores.
|
| 312 |
+
|
| 313 |
+
Args:
|
| 314 |
+
model: QRRanker model
|
| 315 |
+
tokenizer: Tokenizer
|
| 316 |
+
question: Query string
|
| 317 |
+
paragraphs: List of paragraph dicts with 'idx' and 'paragraph_text'
|
| 318 |
+
qr_head_list: QR head list string (e.g., "20-15,21-11,17-27,...")
|
| 319 |
+
device: torch device
|
| 320 |
+
|
| 321 |
+
Returns:
|
| 322 |
+
ranked_ids: List of paragraph IDs sorted by relevance
|
| 323 |
+
scores: Corresponding relevance scores
|
| 324 |
+
"""
|
| 325 |
+
# Build input sequence
|
| 326 |
+
prompt_prefix = '<|im_start|>user\n'
|
| 327 |
+
retrieval_instruction = "Here are some retrieved chunks:\n\n"
|
| 328 |
+
|
| 329 |
+
chunk_part = prompt_prefix + retrieval_instruction
|
| 330 |
+
chunk_ranges = []
|
| 331 |
+
|
| 332 |
+
for i, p in enumerate(paragraphs):
|
| 333 |
+
text = p.get('title', '') + ': ' + p['paragraph_text']
|
| 334 |
+
chunk_part += f"[{i+1}]"
|
| 335 |
+
start = len(chunk_part)
|
| 336 |
+
chunk_part += ' ' + text.strip()
|
| 337 |
+
end = len(chunk_part)
|
| 338 |
+
chunk_ranges.append([start, end])
|
| 339 |
+
chunk_part += '\n\n'
|
| 340 |
+
|
| 341 |
+
query_part = f"Use the retrieved chunks to answer the user's query.\n\nQuery: {question}"
|
| 342 |
+
full_seq = chunk_part + query_part
|
| 343 |
+
|
| 344 |
+
# Tokenize with offset mapping
|
| 345 |
+
inputs = tokenizer(
|
| 346 |
+
full_seq,
|
| 347 |
+
max_length=262144,
|
| 348 |
+
truncation=True,
|
| 349 |
+
return_tensors='pt',
|
| 350 |
+
return_offsets_mapping=True,
|
| 351 |
+
add_special_tokens=False
|
| 352 |
+
)
|
| 353 |
+
|
| 354 |
+
input_ids = inputs['input_ids'].to(device)
|
| 355 |
+
attention_mask = inputs['attention_mask'].to(device)
|
| 356 |
+
offset_mapping = inputs['offset_mapping'][0]
|
| 357 |
+
|
| 358 |
+
# Build character-to-token mapping
|
| 359 |
+
char_to_token = {}
|
| 360 |
+
for i, (s, e) in enumerate(offset_mapping):
|
| 361 |
+
for j in range(s, e):
|
| 362 |
+
char_to_token[j] = i
|
| 363 |
+
|
| 364 |
+
# Map chunk character ranges to token ranges
|
| 365 |
+
token_chunk_ranges = []
|
| 366 |
+
for start, end in chunk_ranges:
|
| 367 |
+
token_start = char_to_token.get(start, 0)
|
| 368 |
+
token_end = char_to_token.get(end - 1, 0) + 1
|
| 369 |
+
token_chunk_ranges.append([token_start, token_end])
|
| 370 |
+
|
| 371 |
+
# Get query token positions
|
| 372 |
+
query_start_char = full_seq.index(question)
|
| 373 |
+
query_end_char = query_start_char + len(question) - 1
|
| 374 |
+
query_positions = list(range(
|
| 375 |
+
char_to_token[query_start_char],
|
| 376 |
+
char_to_token[query_end_char] + 1
|
| 377 |
+
))
|
| 378 |
+
query_upper_bound = query_positions[-1] + 1
|
| 379 |
+
|
| 380 |
+
# Forward pass with custom cache
|
| 381 |
+
with torch.no_grad():
|
| 382 |
+
# Initialize cache with query token positions
|
| 383 |
+
past_kv = DynamicCacheWithQuery(query_indices=query_positions)
|
| 384 |
+
|
| 385 |
+
# Run model forward pass
|
| 386 |
+
output = model(input_ids, attention_mask, past_key_values=past_kv)
|
| 387 |
+
|
| 388 |
+
# Extract query and key states from cache
|
| 389 |
+
query_cache = output.past_key_values.query_cache
|
| 390 |
+
key_cache = output.past_key_values.key_cache
|
| 391 |
+
|
| 392 |
+
# Compute relevance scores
|
| 393 |
+
scores = compute_qr_scores(
|
| 394 |
+
query_cache, key_cache,
|
| 395 |
+
qr_head_list, token_chunk_ranges, query_upper_bound
|
| 396 |
+
)
|
| 397 |
+
|
| 398 |
+
# Sort by scores (descending)
|
| 399 |
+
sorted_indices = torch.argsort(scores, descending=True).cpu().tolist()
|
| 400 |
+
ranked_ids = [paragraphs[i]['idx'] for i in sorted_indices]
|
| 401 |
+
ranked_scores = [float(scores[i]) for i in sorted_indices]
|
| 402 |
+
|
| 403 |
+
return ranked_ids, ranked_scores
|
| 404 |
+
```
|
| 405 |
+
|
| 406 |
+
## Model Configuration
|
| 407 |
+
|
| 408 |
+
The model configuration includes the following QRRanker-specific parameters:
|
| 409 |
+
|
| 410 |
+
| Parameter | Description |
|
| 411 |
+
|-----------|-------------|
|
| 412 |
+
| `qr_start_layer` | Starting layer index for QR heads |
|
| 413 |
+
| `qr_end_layer` | Ending layer index for QR heads |
|
| 414 |
+
| `qr_head_list` | List of (layer, head) tuples for top QR heads |
|
| 415 |
+
|
| 416 |
+
### Default Top-16 QR Heads
|
| 417 |
+
|
| 418 |
+
```
|
| 419 |
+
20-15, 21-11, 17-27, 23-10, 22-4, 21-10, 21-8, 21-18,
|
| 420 |
+
18-15, 18-19, 17-25, 17-17, 24-13, 17-4, 19-12, 21-31
|
| 421 |
+
```
|
| 422 |
+
|
| 423 |
+
## Command Line Usage
|
| 424 |
+
|
| 425 |
+
```bash
|
| 426 |
+
# Basic inference
|
| 427 |
+
python qr_ranker_inference.py \
|
| 428 |
+
--base_model QRRanker \
|
| 429 |
+
--data_path /path/to/data.json \
|
| 430 |
+
--mode top16
|
| 431 |
+
|
| 432 |
+
# With summary
|
| 433 |
+
python qr_ranker_inference.py \
|
| 434 |
+
--base_model QRRanker \
|
| 435 |
+
--data_path /path/to/data.json \
|
| 436 |
+
--mode top16 \
|
| 437 |
+
--use_summary
|
| 438 |
+
```
|
| 439 |
+
|
| 440 |
+
### Arguments
|
| 441 |
+
|
| 442 |
+
| Argument | Type | Default | Description |
|
| 443 |
+
|----------|------|---------|-------------|
|
| 444 |
+
| `--base_model` | str | required | Path to QRRanker model |
|
| 445 |
+
| `--data_path` | str | required | Path to input data file |
|
| 446 |
+
| `--output_dir` | str | `./outputs` | Output directory |
|
| 447 |
+
| `--mode` | str | `top16` | Mode: `full` (all heads) or `top16` (selected heads) |
|
| 448 |
+
| `--qr_head_list` | str | None | Custom QR head list |
|
| 449 |
+
| `--use_summary` | flag | False | Use summary field in data |
|
| 450 |
+
|
| 451 |
+
|
| 452 |
+
|
| 453 |
+
## Citation

If you use QRRanker, please cite:
|
| 454 |
+
|
| 455 |
+
```bibtex
|
| 456 |
+
@misc{li2026queryfocusedmemoryawarererankerlong,
|
| 457 |
+
title={Query-focused and Memory-aware Reranker for Long Context Processing},
|
| 458 |
+
author={Yuqing Li and Jiangnan Li and Mo Yu and Guoxuan Ding and Zheng Lin and Weiping Wang and Jie Zhou},
|
| 459 |
+
year={2026},
|
| 460 |
+
eprint={2602.12192},
|
| 461 |
+
archivePrefix={arXiv},
|
| 462 |
+
primaryClass={cs.CL},
|
| 463 |
+
url={https://arxiv.org/abs/2602.12192},
|
| 464 |
+
}
|
| 465 |
+
```
|
| 466 |
+
|
| 467 |
+
## License
|
| 468 |
+
|
| 469 |
+
This project is licensed under the Apache 2.0 License.
|