# QRRanker: Query-focused and Memory-aware Reranker for Long Context Processing

<p align="center">
<a href="https://qdcassie-li.github.io/QRRanker/"><b>🌐 Project Page</b></a> |
<a href="https://arxiv.org/abs/2602.12192"><b>📄 Paper</b></a> |
<a href="https://huggingface.co/MindscapeRAG/QRRanker"><b>🤗 Models</b></a>
</p>

QRRanker is a lightweight reranking framework that leverages **Query-focused Retrieval (QR) heads** to produce continuous relevance scores, enabling effective listwise reranking with small-scale models.
## Model Description

Building on existing analyses of retrieval heads in large language models, QRRanker trains models to estimate passage–query relevance from the attention scores of selected **Query-focused Retrieval (QR) heads**. These heads are identified by computing QR scores on seed data and are particularly effective at capturing query–document relevance signals.

Our approach is a **listwise** solution: it leverages the holistic information in the entire candidate shortlist during ranking. It naturally produces **continuous relevance scores**, enabling training on arbitrary retrieval datasets without Likert-scale supervision.

### Key Features

- **Listwise reranking**: leverages holistic information within the entire candidate shortlist during ranking
- **Continuous relevance scores**: enables training on arbitrary retrieval datasets without Likert-scale supervision
- **Selective head usage**: focuses on the top-performing QR attention heads
- **Memory enhancement**: optional contextual summaries improve accuracy on long narratives and dialogues
## Quick Start

### Basic Usage

```python
import torch
from transformers import AutoModel, AutoConfig, AutoTokenizer

# Load model
config = AutoConfig.from_pretrained("MindscapeRAG/QRRanker", trust_remote_code=True)
model = AutoModel.from_pretrained(
    "MindscapeRAG/QRRanker",
    config=config,
    torch_dtype=torch.float16,
    trust_remote_code=True,
)
model.eval()

# Load tokenizer
tokenizer = AutoTokenizer.from_pretrained("MindscapeRAG/QRRanker", trust_remote_code=True)
```
## Input Data Format

Input data should be in JSON format. Each sample contains the following fields:

```json
{
  "id": "sample_001",
  "question": "What is the capital of France?",
  "answer": "Paris",
  "paragraphs": [
    {
      "idx": 0,
      "title": "France",
      "paragraph_text": "Paris is the capital and largest city of France...",
      "is_supporting": true
    },
    {
      "idx": 1,
      "title": "Germany",
      "paragraph_text": "Berlin is the capital of Germany...",
      "is_supporting": false
    }
  ],
  "summary": "Optional summary text..."
}
```

### Field Description

| Field | Type | Required | Description |
|-------|------|----------|-------------|
| `id` | string | Yes | Unique sample identifier |
| `question` | string | Yes | User query/question |
| `answer` | string | No | Ground-truth answer (for evaluation) |
| `paragraphs` | list | Yes | List of candidate paragraphs |
| `paragraphs[].idx` | int | Yes | Paragraph index |
| `paragraphs[].title` | string | No | Paragraph title |
| `paragraphs[].paragraph_text` | string | Yes | Paragraph content |
| `paragraphs[].is_supporting` | bool | No | Whether the paragraph supports the answer (for evaluation) |
| `summary` | string | No | Optional summary information |
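A sample in this format can be sanity-checked with a few lines of plain Python. This is an illustrative sketch (the `validate_sample` helper and field sets are ours, not part of the QRRanker codebase):

```python
import json

# Required fields from the table above; optional fields are not checked.
REQUIRED_SAMPLE_FIELDS = {"id", "question", "paragraphs"}
REQUIRED_PARAGRAPH_FIELDS = {"idx", "paragraph_text"}


def validate_sample(sample: dict) -> list:
    """Return a list of human-readable problems; an empty list means valid."""
    problems = [f"missing field: {f}" for f in REQUIRED_SAMPLE_FIELDS - sample.keys()]
    for i, p in enumerate(sample.get("paragraphs", [])):
        problems += [f"paragraphs[{i}] missing: {f}"
                     for f in REQUIRED_PARAGRAPH_FIELDS - p.keys()]
    return problems


sample = json.loads("""{
  "id": "sample_001",
  "question": "What is the capital of France?",
  "paragraphs": [{"idx": 0, "paragraph_text": "Paris is the capital of France."}]
}""")
print(validate_sample(sample))  # []
```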
## Core Algorithm

### 0. DynamicCacheWithQuery (Custom Cache Class)

This custom cache class is essential for QRRanker. It extends the standard `DynamicCache` to also store query states at specified positions.

```python
from typing import Any, Dict, Optional, Tuple

import torch
from transformers.cache_utils import DynamicCache


class DynamicCacheWithQuery(DynamicCache):
    """
    Custom cache class for QRRanker that stores both key/value states and query states.
    The query states are extracted at specified token positions for attention computation.
    """

    def __init__(self, query_indices=None) -> None:
        super().__init__()
        # Token indices where query states should be saved
        # (use None instead of a mutable default argument)
        self._query_indices = query_indices if query_indices is not None else []
        self.query_cache = []

    def update(
        self,
        key_states: torch.Tensor,
        value_states: torch.Tensor,
        layer_idx: int,
        cache_kwargs: Optional[Dict[str, Any]] = None,
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        """
        Updates the cache with new key_states, value_states, and optionally query_states.

        Parameters:
            key_states: New key states to cache [batch, num_kv_heads, seq_len, head_dim]
            value_states: New value states to cache [batch, num_kv_heads, seq_len, head_dim]
            layer_idx: Index of the layer
            cache_kwargs: Optional dict containing 'query_states' to cache

        Returns:
            Tuple of (updated_key_states, updated_value_states)
        """
        # Update seen-tokens count once per forward pass
        if layer_idx == 0:
            self._seen_tokens += key_states.shape[-2]

        # Update key/value cache
        if key_states is not None:
            if len(self.key_cache) <= layer_idx:
                # Pad any skipped layers with empty tensors
                for _ in range(len(self.key_cache), layer_idx):
                    self.key_cache.append(torch.tensor([]))
                    self.value_cache.append(torch.tensor([]))
                self.key_cache.append(key_states)
                self.value_cache.append(value_states)
            elif not self.key_cache[layer_idx].numel():
                self.key_cache[layer_idx] = key_states
                self.value_cache[layer_idx] = value_states
            else:
                self.key_cache[layer_idx] = torch.cat(
                    [self.key_cache[layer_idx], key_states], dim=-2
                )
                self.value_cache[layer_idx] = torch.cat(
                    [self.value_cache[layer_idx], value_states], dim=-2
                )

        # Update query cache if query_states provided
        query_states = cache_kwargs.get("query_states") if cache_kwargs else None
        if query_states is not None:
            if len(self.query_cache) <= layer_idx:
                self.query_cache.append(query_states)
            else:
                self.query_cache[layer_idx] = torch.cat(
                    [self.query_cache[layer_idx], query_states], dim=-2
                )

        return self.key_cache[layer_idx], self.value_cache[layer_idx]
```
### 1. Attention Weight Computation

```python
import math

import torch


def repeat_kv(hidden_states: torch.Tensor, n_rep: int) -> torch.Tensor:
    """Expand key/value states to match the number of query heads (grouped-query attention)."""
    batch, num_key_value_heads, slen, head_dim = hidden_states.shape
    if n_rep == 1:
        return hidden_states
    hidden_states = hidden_states[:, :, None, :, :].expand(
        batch, num_key_value_heads, n_rep, slen, head_dim
    )
    return hidden_states.reshape(batch, num_key_value_heads * n_rep, slen, head_dim)


def get_causal_mask(attn_weights):
    """Generate a causal mask for q_len query positions over a longer key sequence."""
    query_len, seq_len = attn_weights.size(-2), attn_weights.size(-1)
    causal_mask = torch.ones_like(attn_weights.transpose(-1, -2).squeeze(0))
    causal_mask = torch.triu(causal_mask, diagonal=-(seq_len - query_len))
    causal_mask = causal_mask.transpose(-1, -2)
    causal_mask = (1 - causal_mask) * torch.finfo(causal_mask.dtype).min
    return causal_mask


def get_attn_weights(key_states, query_states):
    """Compute attention weights between query and key states."""
    bsz, num_heads, q_len, head_dim = query_states.size()
    num_key_value_heads = key_states.size(1)
    num_key_value_groups = num_heads // num_key_value_heads

    # Expand key states to match query heads
    key_states = repeat_kv(key_states, num_key_value_groups)

    # Scaled dot-product attention
    scale = 1.0 / math.sqrt(head_dim)
    scaled_queries = query_states * scale
    attn_weights = torch.matmul(scaled_queries, key_states.transpose(2, 3))

    # Apply causal mask
    causal_mask = get_causal_mask(attn_weights).to(attn_weights.device)
    attn_weights += causal_mask.unsqueeze(0)

    # Softmax via logsumexp for numerical stability
    attn_lses = torch.logsumexp(attn_weights, dim=-1, keepdim=True)
    attn_weights = torch.exp(attn_weights - attn_lses)

    return attn_weights
```
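The `diagonal=-(seq_len - query_len)` offset in `get_causal_mask` handles the case where the cached key sequence is longer than the query span: query position `i` may attend to keys `j <= i + (kv_len - q_len)`. A toy standalone check (not from the repo, just the same `triu` trick on a 2×5 case):

```python
import torch

# With 5 cached keys and 2 query positions, query i may attend
# keys j <= i + (kv_len - q_len) = i + 3.
q_len, kv_len = 2, 5
mask = torch.ones(kv_len, q_len)                       # [kv_len, q_len] layout
mask = torch.triu(mask, diagonal=-(kv_len - q_len)).T  # back to [q_len, kv_len]
print(mask)
# tensor([[1., 1., 1., 1., 0.],
#         [1., 1., 1., 1., 1.]])
```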
### 2. QRRanker Score Computation

```python
import torch


def compute_qr_scores(
    query_cache,
    key_cache,
    qr_head_list,
    chunk_ranges,
    query_upper_bound,
):
    """
    Compute QRRanker attention scores for document chunks.

    Args:
        query_cache: List of query states from each layer
        key_cache: List of key states from each layer
        qr_head_list: String of QR heads, e.g., "20-15,21-11,17-27,..."
        chunk_ranges: List of [start, end] token positions for each chunk
        query_upper_bound: Upper-bound token position for the query

    Returns:
        scores: Tensor of shape [num_chunks] with relevance scores
    """
    all_head_scores = []

    for key_state, query_state in zip(key_cache, query_cache):
        # Attention of the cached query states over all preceding tokens
        attn_weights = get_attn_weights(
            key_state[:, :, :query_upper_bound, :],
            query_state
        )
        # Average over query positions
        attn_weights = attn_weights.mean(dim=-2)

        # Aggregate scores for each chunk
        chunk_scores = []
        for start, end in chunk_ranges:
            chunk_scores.append(attn_weights[:, :, start:end].sum(dim=-1))
        chunk_scores = torch.stack(chunk_scores, dim=2)
        all_head_scores.append(chunk_scores)

    # Stack all layers: [batch, num_layers, num_heads, num_chunks]
    all_head_scores = torch.stack(all_head_scores, dim=1).float()

    # Select specific QR heads
    if qr_head_list is not None:
        head_set = [tuple(map(int, h.split('-'))) for h in qr_head_list.split(',')]
        indices = torch.tensor(head_set).to(all_head_scores.device)
        layers, heads = indices[:, 0], indices[:, 1]
        all_head_scores = all_head_scores[:, layers, heads, :]

    # Sum over selected heads
    scores = all_head_scores.sum(dim=1).squeeze(0)

    return scores
```
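The `qr_head_list` string uses a simple `layer-head` comma-separated format; the parsing step above can be checked on its own:

```python
# Parse the "layer-head,layer-head,..." format used by qr_head_list into
# (layer, head) integer pairs (the same one-liner used in compute_qr_scores).
qr_head_list = "20-15,21-11,17-27"
head_set = [tuple(map(int, h.split('-'))) for h in qr_head_list.split(',')]
print(head_set)  # [(20, 15), (21, 11), (17, 27)]
```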
### 3. Complete Inference Pipeline

```python
import torch

# Assumes DynamicCacheWithQuery (section 0 above) and compute_qr_scores
# (section 2 above) are in scope.


def rerank_documents(model, tokenizer, question, paragraphs, qr_head_list, device):
    """
    Rerank documents based on QRRanker scores.

    Args:
        model: QRRanker model
        tokenizer: Tokenizer
        question: Query string
        paragraphs: List of paragraph dicts with 'idx' and 'paragraph_text'
        qr_head_list: QR head list string (e.g., "20-15,21-11,17-27,...")
        device: torch device

    Returns:
        ranked_ids: List of paragraph IDs sorted by relevance
        ranked_scores: Corresponding relevance scores
    """
    # Build input sequence
    prompt_prefix = '<|im_start|>user\n'
    retrieval_instruction = "Here are some retrieved chunks:\n\n"

    chunk_part = prompt_prefix + retrieval_instruction
    chunk_ranges = []

    for i, p in enumerate(paragraphs):
        text = p.get('title', '') + ': ' + p['paragraph_text']
        chunk_part += f"[{i+1}]"
        start = len(chunk_part)
        chunk_part += ' ' + text.strip()
        end = len(chunk_part)
        chunk_ranges.append([start, end])
        chunk_part += '\n\n'

    query_part = f"Use the retrieved chunks to answer the user's query.\n\nQuery: {question}"
    full_seq = chunk_part + query_part

    # Tokenize with offset mapping
    inputs = tokenizer(
        full_seq,
        max_length=262144,
        truncation=True,
        return_tensors='pt',
        return_offsets_mapping=True,
        add_special_tokens=False,
    )

    input_ids = inputs['input_ids'].to(device)
    attention_mask = inputs['attention_mask'].to(device)
    offset_mapping = inputs['offset_mapping'][0]

    # Build character-to-token mapping
    char_to_token = {}
    for i, (s, e) in enumerate(offset_mapping):
        for j in range(int(s), int(e)):
            char_to_token[j] = i

    # Map chunk character ranges to token ranges
    token_chunk_ranges = []
    for start, end in chunk_ranges:
        token_start = char_to_token.get(start, 0)
        token_end = char_to_token.get(end - 1, 0) + 1
        token_chunk_ranges.append([token_start, token_end])

    # Get query token positions; use rindex since the question text may also
    # appear verbatim inside a retrieved chunk
    query_start_char = full_seq.rindex(question)
    query_end_char = query_start_char + len(question) - 1
    query_positions = list(range(
        char_to_token[query_start_char],
        char_to_token[query_end_char] + 1
    ))
    query_upper_bound = query_positions[-1] + 1

    # Forward pass with custom cache
    with torch.no_grad():
        # Initialize cache with query token positions
        past_kv = DynamicCacheWithQuery(query_indices=query_positions)

        # Run model forward pass
        output = model(input_ids, attention_mask, past_key_values=past_kv)

        # Extract query and key states from cache
        query_cache = output.past_key_values.query_cache
        key_cache = output.past_key_values.key_cache

        # Compute relevance scores
        scores = compute_qr_scores(
            query_cache, key_cache,
            qr_head_list, token_chunk_ranges, query_upper_bound
        )

    # Sort by scores (descending)
    sorted_indices = torch.argsort(scores, descending=True).cpu().tolist()
    ranked_ids = [paragraphs[i]['idx'] for i in sorted_indices]
    ranked_scores = [float(scores[i]) for i in sorted_indices]

    return ranked_ids, ranked_scores
```
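The character-offset bookkeeping in `rerank_documents` can be exercised in isolation, without a model or tokenizer. A standalone sketch on toy paragraphs (the data is made up; the loop body is the same as above):

```python
# Standalone check of the chunk-range bookkeeping used in rerank_documents.
paragraphs = [
    {"idx": 0, "title": "France", "paragraph_text": "Paris is the capital."},
    {"idx": 1, "title": "Germany", "paragraph_text": "Berlin is the capital."},
]

chunk_part = '<|im_start|>user\n' + "Here are some retrieved chunks:\n\n"
chunk_ranges = []
for i, p in enumerate(paragraphs):
    text = p.get('title', '') + ': ' + p['paragraph_text']
    chunk_part += f"[{i+1}]"
    start = len(chunk_part)
    chunk_part += ' ' + text.strip()
    end = len(chunk_part)
    chunk_ranges.append([start, end])
    chunk_part += '\n\n'

# Each [start, end) span recovers exactly its paragraph's text.
for (start, end), p in zip(chunk_ranges, paragraphs):
    assert chunk_part[start:end].strip() == p['title'] + ': ' + p['paragraph_text']
print(chunk_ranges)
```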
## Model Configuration

The model configuration includes the following QRRanker-specific parameters:

| Parameter | Description |
|-----------|-------------|
| `qr_start_layer` | Starting layer index for QR heads |
| `qr_end_layer` | Ending layer index for QR heads |
| `qr_head_list` | List of (layer, head) tuples for the top QR heads |

### Default Top-16 QR Heads

```
20-15, 21-11, 17-27, 23-10, 22-4, 21-10, 21-8, 21-18,
18-15, 18-19, 17-25, 17-17, 24-13, 17-4, 19-12, 21-31
```
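Selecting these heads out of the full `[batch, num_layers, num_heads, num_chunks]` score tensor uses advanced indexing, as in `compute_qr_scores`. A sketch with random scores and toy shapes (the layer/head counts here are illustrative, not the model's actual configuration):

```python
import torch

# Pick the configured (layer, head) pairs out of a full score tensor.
num_layers, num_heads, num_chunks = 28, 32, 5
all_head_scores = torch.randn(1, num_layers, num_heads, num_chunks)

head_set = [(20, 15), (21, 11), (17, 27)]          # first three of the default list
indices = torch.tensor(head_set)
layers, heads = indices[:, 0], indices[:, 1]
selected = all_head_scores[:, layers, heads, :]    # [1, num_selected, num_chunks]
scores = selected.sum(dim=1).squeeze(0)            # [num_chunks]
print(scores.shape)  # torch.Size([5])
```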
## Command Line Usage

```bash
# Basic inference
python qr_ranker_inference.py \
    --base_model QRRanker \
    --data_path /path/to/data.json \
    --mode top16

# With summary
python qr_ranker_inference.py \
    --base_model QRRanker \
    --data_path /path/to/data.json \
    --mode top16 \
    --use_summary
```
### Arguments

| Argument | Type | Default | Description |
|----------|------|---------|-------------|
| `--base_model` | str | required | Path to the QRRanker model |
| `--data_path` | str | required | Path to the input data file |
| `--output_dir` | str | `./outputs` | Output directory |
| `--mode` | str | `top16` | Mode: `full` (all heads) or `top16` (selected heads) |
| `--qr_head_list` | str | None | Custom QR head list |
| `--use_summary` | flag | False | Use the `summary` field in the data |
## Citation

If you use QRRanker, please cite:

```bibtex
@misc{li2026queryfocusedmemoryawarererankerlong,
      title={Query-focused and Memory-aware Reranker for Long Context Processing},
      author={Yuqing Li and Jiangnan Li and Mo Yu and Guoxuan Ding and Zheng Lin and Weiping Wang and Jie Zhou},
      year={2026},
      eprint={2602.12192},
      archivePrefix={arXiv},
      primaryClass={cs.CL},
      url={https://arxiv.org/abs/2602.12192},
}
```
## License

This project is licensed under the Apache 2.0 License.