---
license: mit
language:
- en
library_name: transformers
tags:
- rag
- router
- multimodal
- retrieval
- query-routing
- qwen3
pipeline_tag: text-classification
datasets:
- ananoymous/Wiki-ss
- ananoymous/DUDE
- ananoymous/TATDQA
- ananoymous/ArxivQA
- ananoymous/FinQA
- ananoymous/FinReport
- ananoymous/FinSlides
- ananoymous/ConvFinQA
- ananoymous/MP-DocVQA
- ananoymous/SciQAG
- ananoymous/VQAonBD
---

# IRouterLM: Adaptive Query Routing for Multimodal RAG

<p align="center">
  <a href="https://github.com/ananoymous-submission/sigir26">GitHub</a> · <a href="https://hf.co/collections/ananoymous/irouterlm">Training Data</a>
</p>

> A lightweight query-aware router that dynamically selects the optimal retrieval modality and architecture per query. IRouterLM achieves **state-of-the-art accuracy (0.76 nDCG@5)** while reducing latency by **90%** compared to the strongest baseline.

## Model Description

IRouterLM is a fine-tuned Qwen3-0.6B model that classifies queries into optimal RAG retrieval strategies. Given a user query, the model predicts which retrieval pipeline will yield the best results while balancing accuracy and latency.

### Supported Strategies

| Strategy ID | Strategy Name | Description |
|-------------|--------------|-------------|
| 0 | `MULTIMODAL_RERANK` | Multimodal dense retrieval + late-interaction reranking |
| 1 | `MULTIMODAL-SINGLE` | Single-stage multimodal dense retrieval |
| 2 | `TEXT_RERANK` | Text dense retrieval + late-interaction reranking |
| 3 | `TEXT-SINGLE` | Single-stage text dense retrieval |

## Quick Start

```python
from transformers import AutoModel, AutoTokenizer
import torch

# Load model and tokenizer
model = AutoModel.from_pretrained("ananoymous/IRouterLM", trust_remote_code=True)
tokenizer = AutoTokenizer.from_pretrained("ananoymous/IRouterLM")

# Example query
query = "What was the revenue growth in Q3 2024?"
inputs = tokenizer(query, return_tensors="pt")

# Get prediction
with torch.no_grad():
    outputs = model(**inputs)
    probs = torch.softmax(outputs["logits"], dim=-1)
    prediction = probs.argmax(dim=-1).item()

# Strategy mapping
strategies = ["MULTIMODAL_RERANK", "MULTIMODAL-SINGLE", "TEXT_RERANK", "TEXT-SINGLE"]
print(f"Predicted strategy: {strategies[prediction]}")
print(f"Confidence: {probs[0][prediction]:.2%}")
```

### Using the `predict` Method

```python
result = model.predict(inputs["input_ids"], inputs["attention_mask"])
print(f"Strategy: {result['strategy_names'][0]}")
print(f"Probabilities: {result['probabilities']}")
```

## Architecture

- **Base Model**: Qwen3-0.6B
- **Fine-tuning**: LoRA (rank=16, alpha=32)
- **Target Modules**: q_proj, k_proj, v_proj, o_proj, gate_proj, up_proj, down_proj
- **Classification Head**: Mean pooling + Linear (1024 → 4)
- **Training Loss**: Weighted KL Divergence with soft labels

```
Query → Qwen3-0.6B (LoRA) → Mean Pooling → Classifier → Strategy Prediction
```
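
The head described above can be sketched in a few lines (a minimal illustration, assuming hidden size 1024 and attention-mask-aware mean pooling; the class and variable names are hypothetical, not the model's actual module names):

```python
import torch
import torch.nn as nn

class RouterHead(nn.Module):
    """Mean pooling over token states followed by a linear classifier."""
    def __init__(self, hidden_size=1024, num_strategies=4):
        super().__init__()
        self.classifier = nn.Linear(hidden_size, num_strategies)

    def forward(self, hidden_states, attention_mask):
        # Mask out padding tokens before averaging
        mask = attention_mask.unsqueeze(-1).float()        # (B, T, 1)
        summed = (hidden_states * mask).sum(dim=1)         # (B, H)
        pooled = summed / mask.sum(dim=1).clamp(min=1e-9)  # (B, H)
        return self.classifier(pooled)                     # (B, num_strategies)

head = RouterHead()
states = torch.randn(2, 8, 1024)   # stand-in for Qwen3 hidden states
mask = torch.ones(2, 8, dtype=torch.long)
logits = head(states, mask)
print(logits.shape)  # torch.Size([2, 4])
```

Masked mean pooling keeps padded positions from diluting the query representation, which matters for short queries padded to a common length.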

## Training Details

### Dataset

The model was trained on 80,000+ queries from 11 benchmarks:

| Domain | Datasets |
|--------|----------|
| Financial | FinReport, FinSlides, FinQA, ConvFinQA, TAT-DQA |
| Scientific | ArxivQA, SciQAG |
| General | Wiki-SS, MP-DocVQA, DUDE, VQAonBD |

### Hyperparameters

| Parameter | Value |
|-----------|-------|
| Learning Rate | 1e-4 |
| Batch Size | 16 |
| Epochs | 2 |
| Weight Decay | 0.01 |
| Warmup Ratio | 0.1 |
| Scheduler | Cosine |
| Precision | bfloat16 |
| λ (trade-off) | 0.0 (accuracy-focused) |
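
The weighted KL objective over soft strategy labels might look like the following sketch; the soft-label values and per-class weights here are illustrative assumptions, not the exact recipe used in training:

```python
import torch
import torch.nn.functional as F

def weighted_kl_loss(logits, soft_labels, class_weights):
    """KL(soft_labels || model) with per-strategy weights on each term."""
    log_probs = F.log_softmax(logits, dim=-1)  # (B, 4)
    # Element-wise KL terms, weighted per strategy class
    kl = soft_labels * (soft_labels.clamp(min=1e-9).log() - log_probs)
    return (kl * class_weights).sum(dim=-1).mean()

logits = torch.tensor([[2.0, 0.5, 0.1, -1.0]])
soft_labels = torch.tensor([[0.7, 0.2, 0.08, 0.02]])  # hypothetical targets
weights = torch.ones(4)
loss = weighted_kl_loss(logits, soft_labels, weights)
print(loss.item())
```

Soft labels let the router learn that several strategies can be near-optimal for a query, rather than forcing a single hard target.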

## Intended Use

IRouterLM is designed for:

- **RAG Systems**: Automatically select the optimal retrieval strategy per query
- **Document QA**: Route queries to text-only or multimodal pipelines based on query semantics
- **Cost Optimization**: Reduce computational costs by avoiding expensive pipelines when simpler ones suffice
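
In a RAG system the router's prediction can drive pipeline dispatch directly; a minimal sketch (the pipeline functions and `predict_fn` here are hypothetical placeholders, not part of this model's API):

```python
STRATEGIES = ["MULTIMODAL_RERANK", "MULTIMODAL-SINGLE", "TEXT_RERANK", "TEXT-SINGLE"]

def route(query, predict_fn, pipelines):
    """Pick a retrieval pipeline based on the router's predicted strategy ID."""
    strategy = STRATEGIES[predict_fn(query)]
    return pipelines[strategy](query)

# Hypothetical placeholder pipelines for illustration
pipelines = {name: (lambda q, n=name: f"{n}:{q}") for name in STRATEGIES}
result = route("revenue growth in Q3?", lambda q: 3, pipelines)
print(result)  # TEXT-SINGLE:revenue growth in Q3?
```

A real deployment would replace `predict_fn` with the model's `predict` call and the placeholder pipelines with the four retrieval stacks from the strategy table.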

### Limitations

- Trained on English queries only
- Optimized for document retrieval tasks (financial, scientific, general domains)

## License

MIT License

## Acknowledgments

This work builds on:
- [Qwen3](https://huggingface.co/Qwen/Qwen3-0.6B-Base) for the base model
- [ColPali](https://github.com/illuin-tech/colpali) for multimodal late-interaction retrieval
- [PEFT](https://github.com/huggingface/peft) for efficient fine-tuning