---
license: apache-2.0
datasets:
- hotpotqa/hotpot_qa
- dgslibisey/MuSiQue
- Aman279/Locomo
- Phospheneser/DetectiveQA
language:
- en
- zh
metrics:
- accuracy
- exact_match
- f1
- recall
base_model:
- Qwen/Qwen3-4B-Instruct-2507
pipeline_tag: text-ranking
tags:
- Rerank
- Memory
---
# QRRanker: Query-focused and Memory-aware Reranker for Long Context Processing

<p align="center">
  <a href="https://qdcassie-li.github.io/QRRanker/"><b>🌐 Project Page</b></a> |
  <a href="https://arxiv.org/abs/2602.12192"><b>📄 Paper</b></a> |
  <a href="https://huggingface.co/MindscapeRAG/QRRanker"><b>🤗 Models</b></a>
</p>

QRRanker is a lightweight reranking framework that leverages **Query-focused Retrieval (QR) heads** to produce continuous relevance scores, enabling effective listwise reranking with small-scale models.

## Model Description

Building on existing analyses of retrieval heads in large language models, QRRanker trains models to estimate passage–query relevance from the attention scores of selected **Query-focused Retrieval (QR) heads**. These heads are identified through QR score computation on seed data and are particularly effective at capturing query–document relevance signals.

Our approach provides a **listwise solution** that leverages the holistic information within the entire candidate shortlist during ranking. It naturally produces **continuous relevance scores**, enabling training on arbitrary retrieval datasets without requiring Likert-scale supervision.

### Key Features

- **Listwise Reranking**: Leverages holistic information within the entire candidate shortlist during ranking
- **Continuous Relevance Scores**: Enables training on arbitrary retrieval datasets without requiring Likert-scale supervision
- **Selective Head Usage**: Focuses on top-performing QR attention heads
- **Memory Enhancement**: Optional contextual summaries for improved accuracy on long narratives and dialogues

## Quick Start

### Basic Usage

```python
import torch
from transformers import AutoModel, AutoConfig, AutoTokenizer

# Load model
config = AutoConfig.from_pretrained("MindscapeRAG/QRRanker", trust_remote_code=True)
model = AutoModel.from_pretrained(
    "MindscapeRAG/QRRanker",
    config=config,
    torch_dtype=torch.float16,
    trust_remote_code=True,
)
model.eval()

# Load tokenizer
tokenizer = AutoTokenizer.from_pretrained("MindscapeRAG/QRRanker", trust_remote_code=True)
```

## Input Data Format

Input data should be in JSON format. Each sample contains the following fields:

```json
{
  "id": "sample_001",
  "question": "What is the capital of France?",
  "answer": "Paris",
  "paragraphs": [
    {
      "idx": 0,
      "title": "France",
      "paragraph_text": "Paris is the capital and largest city of France...",
      "is_supporting": true
    },
    {
      "idx": 1,
      "title": "Germany",
      "paragraph_text": "Berlin is the capital of Germany...",
      "is_supporting": false
    }
  ],
  "summary": "Optional summary text..."
}
```

### Field Description

| Field | Type | Required | Description |
|-------|------|----------|-------------|
| `id` | string | Yes | Unique sample identifier |
| `question` | string | Yes | User query/question |
| `answer` | string | No | Ground truth answer (for evaluation) |
| `paragraphs` | list | Yes | List of candidate paragraphs |
| `paragraphs[].idx` | int | Yes | Paragraph index |
| `paragraphs[].title` | string | No | Paragraph title |
| `paragraphs[].paragraph_text` | string | Yes | Paragraph content |
| `paragraphs[].is_supporting` | bool | No | Whether the paragraph supports the answer (for evaluation) |
| `summary` | string | No | Optional summary information |
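
For a quick check of input files, the required fields above can be validated with a few lines of standard-library Python. This is a minimal sketch: the field names follow the table above, but the helper name `validate_sample` is ours, not part of the QRRanker codebase.

```python
import json

REQUIRED_SAMPLE_FIELDS = {"id", "question", "paragraphs"}
REQUIRED_PARAGRAPH_FIELDS = {"idx", "paragraph_text"}

def validate_sample(sample: dict) -> list:
    """Return a list of problems found in one sample (empty list means valid)."""
    problems = [f"missing field: {f}" for f in sorted(REQUIRED_SAMPLE_FIELDS - sample.keys())]
    for i, p in enumerate(sample.get("paragraphs", [])):
        problems += [f"paragraph {i} missing: {f}"
                     for f in sorted(REQUIRED_PARAGRAPH_FIELDS - p.keys())]
    return problems

sample = json.loads('{"id": "s1", "question": "Q?", "paragraphs": [{"idx": 0, "paragraph_text": "..."}]}')
print(validate_sample(sample))  # []
```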

## Core Algorithm

### 0. DynamicCacheWithQuery (Custom Cache Class)

This custom cache class is essential for QRRanker. It extends the standard `DynamicCache` to also store query states at specified positions.

```python
from typing import Any, Dict, Optional, Tuple

import torch
from transformers.cache_utils import DynamicCache


class DynamicCacheWithQuery(DynamicCache):
    """
    Custom cache class for QRRanker that stores both key/value states and query states.
    The query states are extracted at specified token positions for attention computation.
    """

    def __init__(self, query_indices=None) -> None:
        super().__init__()
        # Token indices where query states should be saved
        # (None rather than a mutable default avoids sharing state across instances)
        self._query_indices = query_indices if query_indices is not None else []
        self.query_cache = []

    def update(
        self,
        key_states: torch.Tensor,
        value_states: torch.Tensor,
        layer_idx: int,
        cache_kwargs: Optional[Dict[str, Any]] = None,
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        """
        Updates the cache with new key_states, value_states, and optionally query_states.

        Parameters:
            key_states: New key states to cache [batch, num_kv_heads, seq_len, head_dim]
            value_states: New value states to cache [batch, num_kv_heads, seq_len, head_dim]
            layer_idx: Index of the layer
            cache_kwargs: Optional dict containing 'query_states' to cache

        Returns:
            Tuple of (updated_key_states, updated_value_states)
        """
        # Update seen tokens count
        if layer_idx == 0:
            self._seen_tokens += key_states.shape[-2]

        # Update key/value cache
        if key_states is not None:
            if len(self.key_cache) <= layer_idx:
                # Fill any skipped layers with empty tensors, then append the new states
                for _ in range(len(self.key_cache), layer_idx):
                    self.key_cache.append(torch.tensor([]))
                    self.value_cache.append(torch.tensor([]))
                self.key_cache.append(key_states)
                self.value_cache.append(value_states)
            elif not self.key_cache[layer_idx].numel():
                self.key_cache[layer_idx] = key_states
                self.value_cache[layer_idx] = value_states
            else:
                self.key_cache[layer_idx] = torch.cat(
                    [self.key_cache[layer_idx], key_states], dim=-2
                )
                self.value_cache[layer_idx] = torch.cat(
                    [self.value_cache[layer_idx], value_states], dim=-2
                )

        # Update query cache if query_states provided
        query_states = cache_kwargs.get("query_states") if cache_kwargs else None
        if query_states is not None:
            if len(self.query_cache) <= layer_idx:
                self.query_cache.append(query_states)
            else:
                self.query_cache[layer_idx] = torch.cat(
                    [self.query_cache[layer_idx], query_states], dim=-2
                )

        return self.key_cache[layer_idx], self.value_cache[layer_idx]
```
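
The layer bookkeeping in `update` grows the per-layer lists lazily: skipped layers are filled with empty placeholders, and revisiting a layer concatenates along the sequence dimension. The same pattern can be sketched with plain Python lists (our own illustration, with strings standing in for tensors and `+` standing in for `torch.cat`):

```python
def update_layer(cache, layer_idx, new_state):
    """Append new_state to cache[layer_idx], growing the list with empty slots as needed."""
    if len(cache) <= layer_idx:
        # Fill any skipped layers with empty placeholders, then append
        cache.extend("" for _ in range(len(cache), layer_idx))
        cache.append(new_state)
    elif not cache[layer_idx]:
        cache[layer_idx] = new_state
    else:
        cache[layer_idx] = cache[layer_idx] + new_state  # concat along the "sequence" axis

cache = []
update_layer(cache, 0, "ab")
update_layer(cache, 0, "cd")   # second chunk is concatenated onto layer 0
update_layer(cache, 2, "xy")   # layer 1 was skipped, so it gets a placeholder
print(cache)  # ['abcd', '', 'xy']
```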

### 1. Attention Weight Computation

```python
import math

import torch


def repeat_kv(hidden_states: torch.Tensor, n_rep: int) -> torch.Tensor:
    """Expand key/value states to match the number of query heads."""
    batch, num_key_value_heads, slen, head_dim = hidden_states.shape
    if n_rep == 1:
        return hidden_states
    hidden_states = hidden_states[:, :, None, :, :].expand(
        batch, num_key_value_heads, n_rep, slen, head_dim
    )
    return hidden_states.reshape(batch, num_key_value_heads * n_rep, slen, head_dim)


def get_causal_mask(attn_weights):
    """Generate a causal mask for queries at the last positions of the sequence."""
    query_len, seq_len = attn_weights.size(-2), attn_weights.size(-1)
    causal_mask = torch.ones_like(attn_weights.transpose(-1, -2).squeeze(0))
    causal_mask = torch.triu(causal_mask, diagonal=-(seq_len - query_len))
    causal_mask = causal_mask.transpose(-1, -2)
    causal_mask = (1 - causal_mask) * torch.finfo(causal_mask.dtype).min
    return causal_mask


def get_attn_weights(key_states, query_states):
    """Compute attention weights between query and key states."""
    bsz, num_heads, q_len, head_dim = query_states.size()
    num_key_value_heads = key_states.size(1)
    num_key_value_groups = num_heads // num_key_value_heads

    # Expand key states to match query heads
    key_states = repeat_kv(key_states, num_key_value_groups)

    # Scaled dot-product attention
    scale = 1.0 / math.sqrt(head_dim)
    scaled_queries = query_states * scale
    attn_weights = torch.matmul(scaled_queries, key_states.transpose(2, 3))

    # Apply causal mask
    causal_mask = get_causal_mask(attn_weights).to(attn_weights.device)
    attn_weights += causal_mask.unsqueeze(0)

    # Numerically stable softmax normalization
    attn_lses = torch.logsumexp(attn_weights, dim=-1, keepdim=True)
    attn_weights = torch.exp(attn_weights - attn_lses)

    return attn_weights
```
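
As a sanity check, this computation can be cross-checked against a compact reference that uses `torch.softmax` directly. The sketch below is our own (synthetic tensors, hypothetical shapes); it relies on the same assumption as `get_causal_mask`, namely that the queries occupy the last `q_len` absolute positions of the sequence:

```python
import math
import torch

def reference_attn(key_states, query_states):
    """Reference attention: queries sit at the last q_len absolute positions."""
    n_rep = query_states.size(1) // key_states.size(1)
    k = key_states.repeat_interleave(n_rep, dim=1)  # same head layout as repeat_kv
    w = (query_states / math.sqrt(query_states.size(-1))) @ k.transpose(2, 3)
    q_len, s_len = w.size(-2), w.size(-1)
    # Query i (absolute position s_len - q_len + i) may attend keys 0..s_len - q_len + i
    mask = torch.triu(torch.ones(q_len, s_len, dtype=torch.bool), diagonal=s_len - q_len + 1)
    w = w.masked_fill(mask, float("-inf"))
    return torch.softmax(w, dim=-1)

torch.manual_seed(0)
q = torch.randn(1, 4, 2, 8)   # [batch, num_heads, q_len, head_dim]
k = torch.randn(1, 2, 4, 8)   # [batch, num_kv_heads, kv_len, head_dim]
w = reference_attn(k, q)
print(w.shape)                # torch.Size([1, 4, 2, 4])
```

Each output row sums to 1, and the final key position is masked out for the first query, matching the causal geometry above.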

### 2. QRRanker Score Computation

```python
def compute_qr_scores(
    query_cache,
    key_cache,
    qr_head_list,
    chunk_ranges,
    query_upper_bound,
):
    """
    Compute QRRanker attention scores for document chunks.

    Args:
        query_cache: List of query states from each layer
        key_cache: List of key states from each layer
        qr_head_list: String of QR heads, e.g., "20-15,21-11,17-27,..."
        chunk_ranges: List of [start, end] token positions for each chunk
        query_upper_bound: Upper bound token position for the query

    Returns:
        scores: Tensor of shape [num_chunks] with relevance scores
    """
    all_head_scores = []

    for key_state, query_state in zip(key_cache, query_cache):
        # Compute attention weights
        attn_weights = get_attn_weights(
            key_state[:, :, :query_upper_bound, :],
            query_state
        )
        # Average over query positions
        attn_weights = attn_weights.mean(dim=-2)

        # Aggregate scores for each chunk
        chunk_scores = []
        for start, end in chunk_ranges:
            chunk_scores.append(attn_weights[:, :, start:end].sum(dim=-1))
        chunk_scores = torch.stack(chunk_scores, dim=2)
        all_head_scores.append(chunk_scores)

    # Stack all layers: [batch, num_layers, num_heads, num_chunks]
    all_head_scores = torch.stack(all_head_scores, dim=1).float()

    # Select specific QR heads
    if qr_head_list is not None:
        head_set = [tuple(map(int, h.split('-'))) for h in qr_head_list.split(',')]
        indices = torch.tensor(head_set).to(all_head_scores.device)
        layers, heads = indices[:, 0], indices[:, 1]
        all_head_scores = all_head_scores[:, layers, heads, :]

    # Sum over selected heads
    scores = all_head_scores.sum(dim=1).squeeze(0)

    return scores
```
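
The head-selection step relies on advanced indexing with parallel layer/head index tensors: indexing with two 1-D tensors picks one `(layer, head)` pair per entry rather than a cross product. A minimal sketch with synthetic scores (the shapes and the three-head list here are purely illustrative):

```python
import torch

torch.manual_seed(0)
all_head_scores = torch.rand(1, 28, 32, 5)   # [batch, layers, heads, chunks], synthetic
qr_head_list = "20-15,21-11,17-27"           # illustrative "layer-head" pairs

head_set = [tuple(map(int, h.split('-'))) for h in qr_head_list.split(',')]
indices = torch.tensor(head_set)
layers, heads = indices[:, 0], indices[:, 1]
selected = all_head_scores[:, layers, heads, :]   # [1, num_selected, num_chunks]
scores = selected.sum(dim=1).squeeze(0)           # [num_chunks]
print(scores.shape)  # torch.Size([5])
```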

### 3. Complete Inference Pipeline

```python
import torch

from custom_cache_new import DynamicCacheWithQuery


def rerank_documents(model, tokenizer, question, paragraphs, qr_head_list, device):
    """
    Rerank documents based on QRRanker scores.

    Args:
        model: QRRanker model
        tokenizer: Tokenizer
        question: Query string
        paragraphs: List of paragraph dicts with 'idx' and 'paragraph_text'
        qr_head_list: QR head list string (e.g., "20-15,21-11,17-27,...")
        device: torch device

    Returns:
        ranked_ids: List of paragraph IDs sorted by relevance
        ranked_scores: Corresponding relevance scores
    """
    # Build input sequence
    prompt_prefix = '<|im_start|>user\n'
    retrieval_instruction = "Here are some retrieved chunks:\n\n"

    chunk_part = prompt_prefix + retrieval_instruction
    chunk_ranges = []

    for i, p in enumerate(paragraphs):
        text = p.get('title', '') + ': ' + p['paragraph_text']
        chunk_part += f"[{i+1}]"
        start = len(chunk_part)
        chunk_part += ' ' + text.strip()
        end = len(chunk_part)
        chunk_ranges.append([start, end])
        chunk_part += '\n\n'

    query_part = f"Use the retrieved chunks to answer the user's query.\n\nQuery: {question}"
    full_seq = chunk_part + query_part

    # Tokenize with offset mapping
    inputs = tokenizer(
        full_seq,
        max_length=262144,
        truncation=True,
        return_tensors='pt',
        return_offsets_mapping=True,
        add_special_tokens=False
    )

    input_ids = inputs['input_ids'].to(device)
    attention_mask = inputs['attention_mask'].to(device)
    offset_mapping = inputs['offset_mapping'][0]

    # Build character-to-token mapping
    char_to_token = {}
    for i, (s, e) in enumerate(offset_mapping):
        for j in range(s, e):
            char_to_token[j] = i

    # Map chunk character ranges to token ranges
    token_chunk_ranges = []
    for start, end in chunk_ranges:
        token_start = char_to_token.get(start, 0)
        token_end = char_to_token.get(end - 1, 0) + 1
        token_chunk_ranges.append([token_start, token_end])

    # Get query token positions
    query_start_char = full_seq.index(question)
    query_end_char = query_start_char + len(question) - 1
    query_positions = list(range(
        char_to_token[query_start_char],
        char_to_token[query_end_char] + 1
    ))
    query_upper_bound = query_positions[-1] + 1

    # Forward pass with custom cache
    with torch.no_grad():
        # Initialize cache with query token positions
        past_kv = DynamicCacheWithQuery(query_indices=query_positions)

        # Run model forward pass
        output = model(input_ids, attention_mask, past_key_values=past_kv)

    # Extract query and key states from cache
    query_cache = output.past_key_values.query_cache
    key_cache = output.past_key_values.key_cache

    # Compute relevance scores
    scores = compute_qr_scores(
        query_cache, key_cache,
        qr_head_list, token_chunk_ranges, query_upper_bound
    )

    # Sort by scores (descending)
    sorted_indices = torch.argsort(scores, descending=True).cpu().tolist()
    ranked_ids = [paragraphs[i]['idx'] for i in sorted_indices]
    ranked_scores = [float(scores[i]) for i in sorted_indices]

    return ranked_ids, ranked_scores
```
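
The chunk-range bookkeeping in the pipeline hinges on mapping character offsets to token indices via the tokenizer's offset mapping. This standard-library sketch mimics that logic with a toy whitespace tokenizer (our own illustration, not the actual Qwen tokenizer):

```python
import re

text = "Paris is the capital of France"
# Toy offset mapping: (start, end) character span of each whitespace-delimited token
offset_mapping = [(m.start(), m.end()) for m in re.finditer(r"\S+", text)]

char_to_token = {}
for i, (s, e) in enumerate(offset_mapping):
    for j in range(s, e):
        char_to_token[j] = i

# Map a character range (the word "capital") to a half-open token range
start, end = text.index("capital"), text.index("capital") + len("capital")
token_start = char_to_token.get(start, 0)
token_end = char_to_token.get(end - 1, 0) + 1
print(token_start, token_end)  # 3 4
```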

## Model Configuration

The model configuration includes the following QRRanker-specific parameters:

| Parameter | Description |
|-----------|-------------|
| `qr_start_layer` | Starting layer index for QR heads |
| `qr_end_layer` | Ending layer index for QR heads |
| `qr_head_list` | List of (layer, head) tuples for top QR heads |

### Default Top-16 QR Heads

```
20-15, 21-11, 17-27, 23-10, 22-4, 21-10, 21-8, 21-18,
18-15, 18-19, 17-25, 17-17, 24-13, 17-4, 19-12, 21-31
```

## Command Line Usage

```bash
# Basic inference
python qr_ranker_inference.py \
    --base_model QRRanker \
    --data_path /path/to/data.json \
    --mode top16

# With summary
python qr_ranker_inference.py \
    --base_model QRRanker \
    --data_path /path/to/data.json \
    --mode top16 \
    --use_summary
```

### Arguments

| Argument | Type | Default | Description |
|----------|------|---------|-------------|
| `--base_model` | str | required | Path to QRRanker model |
| `--data_path` | str | required | Path to input data file |
| `--output_dir` | str | `./outputs` | Output directory |
| `--mode` | str | `top16` | Mode: `full` (all heads) or `top16` (selected heads) |
| `--qr_head_list` | str | None | Custom QR head list |
| `--use_summary` | flag | False | Use summary field in data |

## Citation

If you use QRRanker, please cite:

```bibtex
@misc{li2026queryfocusedmemoryawarererankerlong,
      title={Query-focused and Memory-aware Reranker for Long Context Processing},
      author={Yuqing Li and Jiangnan Li and Mo Yu and Guoxuan Ding and Zheng Lin and Weiping Wang and Jie Zhou},
      year={2026},
      eprint={2602.12192},
      archivePrefix={arXiv},
      primaryClass={cs.CL},
      url={https://arxiv.org/abs/2602.12192},
}
```

## License

This project is licensed under the Apache 2.0 License.