---
license: mit
language:
- en
library_name: transformers
tags:
- rag
- router
- multimodal
- retrieval
- query-routing
- qwen3
pipeline_tag: text-classification
datasets:
- ananoymous/Wiki-ss
- ananoymous/DUDE
- ananoymous/TATDQA
- ananoymous/ArxivQA
- ananoymous/FinQA
- ananoymous/FinReport
- ananoymous/FinSlides
- ananoymous/ConvFinQA
- ananoymous/MP-DocVQA
- ananoymous/SciQAG
- ananoymous/VQAonBD
---
# IRouterLM: Adaptive Query Routing for Multimodal RAG
<p align="center">
  <a href="https://github.com/ananoymous-submission/sigir26">GitHub</a> ·
  <a href="https://hf.co/collections/ananoymous/irouterlm">Training Data</a>
</p>
> A lightweight query-aware router that dynamically selects the optimal retrieval modality and architecture per query. IRouterLM achieves **state-of-the-art accuracy (0.76 nDCG@5)** while reducing latency by **90%** compared to the strongest baseline.
## Model Description
IRouterLM is a fine-tuned Qwen3-0.6B model that classifies queries into optimal RAG retrieval strategies. Given a user query, the model predicts which retrieval pipeline will yield the best results while balancing accuracy and latency.
### Supported Strategies
| Strategy ID | Strategy Name | Description |
|-------------|--------------|-------------|
| 0 | `MULTIMODAL_RERANK` | Multimodal dense retrieval + late-interaction reranking |
| 1 | `MULTIMODAL-SINGLE` | Single-stage multimodal dense retrieval |
| 2 | `TEXT_RERANK` | Text dense retrieval + late-interaction reranking |
| 3 | `TEXT-SINGLE` | Single-stage text dense retrieval |
## Quick Start
```python
from transformers import AutoModel, AutoTokenizer
import torch

# Load model and tokenizer
model = AutoModel.from_pretrained("ananoymous/IRouterLM", trust_remote_code=True)
tokenizer = AutoTokenizer.from_pretrained("ananoymous/IRouterLM")

# Example query
query = "What was the revenue growth in Q3 2024?"
inputs = tokenizer(query, return_tensors="pt")

# Get prediction
with torch.no_grad():
    outputs = model(**inputs)

probs = torch.softmax(outputs["logits"], dim=-1)
prediction = probs.argmax(dim=-1).item()

# Strategy mapping
strategies = ["MULTIMODAL_RERANK", "MULTIMODAL-SINGLE", "TEXT_RERANK", "TEXT-SINGLE"]
print(f"Predicted strategy: {strategies[prediction]}")
print(f"Confidence: {probs[0][prediction]:.2%}")
```
### Using the `predict` Method
```python
result = model.predict(inputs["input_ids"], inputs["attention_mask"])
print(f"Strategy: {result['strategy_names'][0]}")
print(f"Probabilities: {result['probabilities']}")
```
## Architecture
- **Base Model**: Qwen3-0.6B
- **Fine-tuning**: LoRA (rank=16, alpha=32)
- **Target Modules**: q_proj, k_proj, v_proj, o_proj, gate_proj, up_proj, down_proj
- **Classification Head**: Mean pooling + Linear (1024 → 4)
- **Training Loss**: Weighted KL Divergence with soft labels
```
Query → Qwen3-0.6B (LoRA) → Mean Pooling → Classifier → Strategy Prediction
```
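The pooling-plus-classifier stage above can be sketched in a few lines of PyTorch. This is a hypothetical re-implementation for illustration only (the actual head ships with the repo's remote code); `RouterHead` and the dummy tensors are assumptions, with the 1024 → 4 shapes taken from the description above.

```python
import torch
import torch.nn as nn

HIDDEN, NUM_STRATEGIES = 1024, 4

class RouterHead(nn.Module):
    """Mean pooling over non-padded tokens, then a linear 1024 -> 4 classifier."""
    def __init__(self, hidden=HIDDEN, num_labels=NUM_STRATEGIES):
        super().__init__()
        self.classifier = nn.Linear(hidden, num_labels)

    def forward(self, hidden_states, attention_mask):
        # Mask out padding positions before averaging token embeddings.
        mask = attention_mask.unsqueeze(-1).float()
        pooled = (hidden_states * mask).sum(dim=1) / mask.sum(dim=1).clamp(min=1e-9)
        return self.classifier(pooled)

# Shape check with dummy hidden states: batch of 2 queries, 8 tokens each.
head = RouterHead()
hidden_states = torch.randn(2, 8, HIDDEN)
attention_mask = torch.ones(2, 8, dtype=torch.long)
logits = head(hidden_states, attention_mask)
print(logits.shape)  # torch.Size([2, 4])
```

The attention mask matters here: averaging over padding tokens would let batch composition change a query's pooled representation.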
## Training Details
### Dataset
The model was trained on 80,000+ queries from 11 benchmarks:
| Domain | Datasets |
|--------|----------|
| Financial | FinReport, FinSlides, FinQA, ConvFinQA, TAT-DQA |
| Scientific | ArxivQA, SciQAG |
| General | Wiki-SS, MP-DocVQA, DUDE, VQAonBD |
### Hyperparameters
| Parameter | Value |
|-----------|-------|
| Learning Rate | 1e-4 |
| Batch Size | 16 |
| Epochs | 2 |
| Weight Decay | 0.01 |
| Warmup Ratio | 0.1 |
| Scheduler | Cosine |
| Precision | bfloat16 |
| λ (trade-off) | 0.0 (accuracy-focused) |
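One plausible reading of the λ trade-off and the soft-label KL objective is sketched below. The utility formula `accuracy - λ·latency`, the temperature, and all numbers are assumptions for illustration; the card only states that λ = 0.0 makes training purely accuracy-focused.

```python
import torch
import torch.nn.functional as F

def soft_labels(accuracy, latency, lam=0.0, temperature=0.1):
    # Hypothetical recipe: per-strategy utility trades accuracy against
    # normalized latency, then softmax turns utilities into a target distribution.
    utility = accuracy - lam * latency
    return F.softmax(utility / temperature, dim=-1)

# Illustrative per-strategy retrieval quality and normalized latency for one query.
acc = torch.tensor([0.78, 0.71, 0.74, 0.65])
lat = torch.tensor([1.00, 0.40, 0.55, 0.10])

# λ = 0.0 ignores latency: the most accurate strategy dominates the target.
print(soft_labels(acc, lat, lam=0.0))
# A larger λ shifts probability mass toward cheaper strategies.
print(soft_labels(acc, lat, lam=0.5))

# The router is then trained to match these targets with KL divergence.
logits = torch.randn(1, 4)
target = soft_labels(acc, lat).unsqueeze(0)
loss = F.kl_div(F.log_softmax(logits, dim=-1), target, reduction="batchmean")
```

Soft targets let the router learn that several strategies can be nearly tied for a query, instead of forcing a hard one-hot label.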
## Intended Use
IRouterLM is designed for:
- **RAG Systems**: Automatically select the optimal retrieval strategy per query
- **Document QA**: Route queries to text-only or multimodal pipelines based on query semantics
- **Cost Optimization**: Reduce computational costs by avoiding expensive pipelines when simpler ones suffice
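Wiring the router's prediction into a RAG system reduces to a dispatch table over the four strategy names. The pipeline functions below are hypothetical placeholders (not part of this repo), standing in for real retrieval backends:

```python
# Placeholder pipelines; in practice these would call real retrievers/rerankers.
def multimodal_rerank(q): return f"multimodal+rerank({q})"
def multimodal_single(q): return f"multimodal({q})"
def text_rerank(q): return f"text+rerank({q})"
def text_single(q): return f"text({q})"

PIPELINES = {
    "MULTIMODAL_RERANK": multimodal_rerank,
    "MULTIMODAL-SINGLE": multimodal_single,
    "TEXT_RERANK": text_rerank,
    "TEXT-SINGLE": text_single,
}

def route(query, predicted_strategy):
    # Fall back to the cheapest pipeline if the label is unexpected.
    return PIPELINES.get(predicted_strategy, text_single)(query)

print(route("What was the revenue growth in Q3 2024?", "TEXT_RERANK"))
```

The fallback to the cheapest pipeline is a design choice for this sketch; a production system might instead fall back to the highest-accuracy strategy.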
### Limitations
- Trained on English queries only
- Optimized for document retrieval tasks (financial, scientific, general domains)
## License
MIT License
## Acknowledgments
This work builds on:
- [Qwen3](https://huggingface.co/Qwen/Qwen3-0.6B-Base) for the base model
- [ColPali](https://github.com/illuin-tech/colpali) for multimodal late-interaction retrieval
- [PEFT](https://github.com/huggingface/peft) for efficient fine-tuning