---
license: mit
language:
- en
library_name: transformers
tags:
- rag
- router
- multimodal
- retrieval
- query-routing
- qwen3
pipeline_tag: text-classification
datasets:
- ananoymous/Wiki-ss
- ananoymous/DUDE
- ananoymous/TATDQA
- ananoymous/ArxivQA
- ananoymous/FinQA
- ananoymous/FinReport
- ananoymous/FinSlides
- ananoymous/ConvFinQA
- ananoymous/MP-DocVQA
- ananoymous/SciQAG
- ananoymous/VQAonBD
---

# IRouterLM: Adaptive Query Routing for Multimodal RAG

<p align="center">
<a href="https://github.com/ananoymous-submission/sigir26">GitHub</a> •
<a href="https://hf.co/collections/ananoymous/irouterlm">Training Data</a>
</p>

> A lightweight query-aware router that dynamically selects the optimal retrieval modality and architecture per query. IRouterLM achieves **state-of-the-art accuracy (0.76 nDCG@5)** while reducing latency by **90%** compared to the strongest baseline.

## Model Description

IRouterLM is a fine-tuned Qwen3-0.6B model that classifies queries into optimal RAG retrieval strategies. Given a user query, the model predicts which retrieval pipeline will yield the best results while balancing accuracy and latency.

### Supported Strategies

| Strategy ID | Strategy Name | Description |
|-------------|---------------|-------------|
| 0 | `MULTIMODAL_RERANK` | Multimodal dense retrieval + late-interaction reranking |
| 1 | `MULTIMODAL-SINGLE` | Single-stage multimodal dense retrieval |
| 2 | `TEXT_RERANK` | Text dense retrieval + late-interaction reranking |
| 3 | `TEXT-SINGLE` | Single-stage text dense retrieval |
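
In a deployed system, the predicted strategy ID can drive a simple dispatch table. The sketch below is illustrative only — the `retrieve_*` pipeline functions are hypothetical placeholders, not part of this repository:

```python
# Hypothetical dispatch from a predicted strategy ID to a retrieval pipeline.
# The retrieve_* functions are placeholders standing in for real pipelines.

def retrieve_multimodal_rerank(query: str) -> str:
    return f"multimodal+rerank results for: {query}"

def retrieve_multimodal_single(query: str) -> str:
    return f"multimodal single-stage results for: {query}"

def retrieve_text_rerank(query: str) -> str:
    return f"text+rerank results for: {query}"

def retrieve_text_single(query: str) -> str:
    return f"text single-stage results for: {query}"

# Index order matches the strategy IDs in the table above.
ROUTES = [
    retrieve_multimodal_rerank,   # 0: MULTIMODAL_RERANK
    retrieve_multimodal_single,   # 1: MULTIMODAL-SINGLE
    retrieve_text_rerank,         # 2: TEXT_RERANK
    retrieve_text_single,         # 3: TEXT-SINGLE
]

def route(prediction: int, query: str) -> str:
    """Dispatch a query to the pipeline selected by the router."""
    return ROUTES[prediction](query)
```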

## Quick Start

```python
from transformers import AutoModel, AutoTokenizer
import torch

# Load model and tokenizer
model = AutoModel.from_pretrained("ananoymous/IRouterLM", trust_remote_code=True)
tokenizer = AutoTokenizer.from_pretrained("ananoymous/IRouterLM")

# Example query
query = "What was the revenue growth in Q3 2024?"
inputs = tokenizer(query, return_tensors="pt")

# Get prediction
with torch.no_grad():
    outputs = model(**inputs)
    probs = torch.softmax(outputs["logits"], dim=-1)
    prediction = probs.argmax(dim=-1).item()

# Strategy mapping (index order matches the strategy IDs above)
strategies = ["MULTIMODAL_RERANK", "MULTIMODAL-SINGLE", "TEXT_RERANK", "TEXT-SINGLE"]
print(f"Predicted strategy: {strategies[prediction]}")
print(f"Confidence: {probs[0][prediction]:.2%}")
```

|
### Using the `predict` Method

```python
result = model.predict(inputs["input_ids"], inputs["attention_mask"])
print(f"Strategy: {result['strategy_names'][0]}")
print(f"Probabilities: {result['probabilities']}")
```

## Architecture

- **Base Model**: Qwen3-0.6B
- **Fine-tuning**: LoRA (rank=16, alpha=32)
- **Target Modules**: q_proj, k_proj, v_proj, o_proj, gate_proj, up_proj, down_proj
- **Classification Head**: Mean pooling + Linear (1024 → 4)
- **Training Loss**: Weighted KL divergence with soft labels

```
Query → Qwen3-0.6B (LoRA) → Mean Pooling → Classifier → Strategy Prediction
```
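
The mean-pooling step averages token embeddings while ignoring padding positions, and the resulting 1024-dim vector is projected to 4 strategy logits. Here is a minimal sketch of that head; the function and module names are illustrative, not the repository's actual code:

```python
import torch
import torch.nn as nn

def masked_mean_pool(hidden_states: torch.Tensor, attention_mask: torch.Tensor) -> torch.Tensor:
    """Average token embeddings over the sequence, ignoring padding positions."""
    mask = attention_mask.unsqueeze(-1).to(hidden_states.dtype)  # (B, T, 1)
    summed = (hidden_states * mask).sum(dim=1)                   # (B, H)
    counts = mask.sum(dim=1).clamp(min=1.0)                      # (B, 1)
    return summed / counts

# Illustrative head: 1024-dim pooled hidden state -> 4 strategy logits
classifier = nn.Linear(1024, 4)

hidden = torch.randn(2, 8, 1024)           # (batch, seq_len, hidden)
mask = torch.ones(2, 8, dtype=torch.long)  # no padding in this toy batch
logits = classifier(masked_mean_pool(hidden, mask))
print(logits.shape)  # torch.Size([2, 4])
```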

## Training Details

### Dataset

The model was trained on 80,000+ queries from 11 benchmarks:

| Domain | Datasets |
|--------|----------|
| Financial | FinReport, FinSlides, FinQA, ConvFinQA, TAT-DQA |
| Scientific | ArxivQA, SciQAG |
| General | Wiki-SS, MP-DocVQA, DUDE, VQAonBD |

### Hyperparameters

| Parameter | Value |
|-----------|-------|
| Learning Rate | 1e-4 |
| Batch Size | 16 |
| Epochs | 2 |
| Weight Decay | 0.01 |
| Warmup Ratio | 0.1 |
| Scheduler | Cosine |
| Precision | bfloat16 |
| λ (trade-off) | 0.0 (accuracy-focused) |
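
Given the LoRA rank, alpha, and target modules listed under Architecture, the adapter setup could be reproduced with a `peft` configuration along these lines. This is a sketch, not the exact training script; the dropout value and task type are assumptions not stated in the card:

```python
from peft import LoraConfig

# Sketch of the adapter configuration implied by the card:
# rank=16, alpha=32, applied to attention and MLP projections.
lora_config = LoraConfig(
    r=16,
    lora_alpha=32,
    target_modules=[
        "q_proj", "k_proj", "v_proj", "o_proj",
        "gate_proj", "up_proj", "down_proj",
    ],
    lora_dropout=0.05,               # assumption: not stated in the card
    task_type="FEATURE_EXTRACTION",  # assumption: a custom head sits on top
)
```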

## Intended Use

IRouterLM is designed for:

- **RAG Systems**: Automatically select the optimal retrieval strategy per query
- **Document QA**: Route queries to text-only or multimodal pipelines based on query semantics
- **Cost Optimization**: Reduce computational costs by avoiding expensive pipelines when simpler ones suffice

### Limitations

- Trained on English queries only
- Optimized for document retrieval tasks (financial, scientific, and general domains)

## License

MIT License

## Acknowledgments

This work builds on:

- [Qwen3](https://huggingface.co/Qwen/Qwen3-0.6B-Base) for the base model
- [ColPali](https://github.com/illuin-tech/colpali) for multimodal late-interaction retrieval
- [PEFT](https://github.com/huggingface/peft) for efficient fine-tuning