ananoymous committed on
Commit
17d1d85
·
verified ·
1 Parent(s): 8ac0de1

Update README.md

  1. README.md +144 -4
README.md CHANGED
@@ -2,8 +2,148 @@
2
  license: mit
3
  language:
4
  - en
5
- base_model:
6
- - Qwen/Qwen3-0.6B
7
- pipeline_tag: visual-document-retrieval
8
  library_name: transformers
9
- ---
2
  license: mit
3
  language:
4
  - en
 
 
 
5
  library_name: transformers
6
+ tags:
7
+ - rag
8
+ - router
9
+ - multimodal
10
+ - retrieval
11
+ - query-routing
12
+ - qwen3
13
+ datasets:
14
+ - ananoymous/irouterlm-training-data
15
+ pipeline_tag: text-classification
16
+ ---
17
+
18
+ # IRouterLM: Adaptive Query Routing for Multimodal RAG
19
+
20
+ <p align="center">
21
+ <a href="https://github.com/ananoymous/sigir26">Paper</a> •
22
+ <a href="https://github.com/ananoymous/sigir26">GitHub</a> •
23
+ <a href="https://huggingface.co/datasets/ananoymous/irouterlm-training-data">Training Data</a>
24
+ </p>
25
+
26
+ > A lightweight query-aware router that dynamically selects the optimal retrieval modality and architecture per query. IRouterLM achieves **state-of-the-art accuracy (0.76 nDCG@5)** while reducing latency by **90%** compared to the strongest baseline.
27
+
28
+ ## Model Description
29
+
30
+ IRouterLM is a fine-tuned Qwen3-0.6B model that classifies queries into optimal RAG retrieval strategies. Given a user query, the model predicts which retrieval pipeline will yield the best results while balancing accuracy and latency.
31
+
32
+ ### Supported Strategies
33
+
34
+ | Strategy ID | Strategy Name | Description |
35
+ |-------------|--------------|-------------|
36
+ | 0 | `MULTIMODAL_RERANK` | Multimodal dense retrieval + late-interaction reranking |
37
+ | 1 | `MULTIMODAL-SINGLE` | Single-stage multimodal dense retrieval |
38
+ | 2 | `TEXT_RERANK` | Text dense retrieval + late-interaction reranking |
39
+ | 3 | `TEXT-SINGLE` | Single-stage text dense retrieval |
40
+
41
+ ## Quick Start
42
+
43
+ ```python
44
+ from transformers import AutoModel, AutoTokenizer
45
+ import torch
46
+
47
+ # Load model and tokenizer
48
+ model = AutoModel.from_pretrained("ananoymous/IRouterLM", trust_remote_code=True)
49
+ tokenizer = AutoTokenizer.from_pretrained("ananoymous/IRouterLM")
50
+
51
+ # Example query
52
+ query = "What was the revenue growth in Q3 2024?"
53
+ inputs = tokenizer(query, return_tensors="pt")
54
+
55
+ # Get prediction
56
+ with torch.no_grad():
57
+     outputs = model(**inputs)
58
+ probs = torch.softmax(outputs["logits"], dim=-1)
59
+ prediction = probs.argmax(dim=-1).item()
60
+
61
+ # Strategy mapping
62
+ strategies = ["MULTIMODAL_RERANK", "MULTIMODAL-SINGLE", "TEXT_RERANK", "TEXT-SINGLE"]
63
+ print(f"Predicted strategy: {strategies[prediction]}")
64
+ print(f"Confidence: {probs[0][prediction]:.2%}")
65
+ ```
66
+
67
+ ### Using the `predict` Method
68
+
69
+ ```python
70
+ result = model.predict(inputs["input_ids"], inputs["attention_mask"])
71
+ print(f"Strategy: {result['strategy_names'][0]}")
72
+ print(f"Probabilities: {result['probabilities']}")
73
+ ```
74
+
75
+ ## Architecture
76
+
77
+ - **Base Model**: Qwen3-0.6B
78
+ - **Fine-tuning**: LoRA (rank=16, alpha=32)
79
+ - **Target Modules**: q_proj, k_proj, v_proj, o_proj, gate_proj, up_proj, down_proj
80
+ - **Classification Head**: Mean pooling + Linear (1024 → 4)
81
+ - **Training Loss**: Weighted KL Divergence with soft labels
82
+
83
+ ```
84
+ Query → Qwen3-0.6B (LoRA) → Mean Pooling → Classifier → Strategy Prediction
85
+ ```
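
The pooling and classification stage listed above can be sketched as follows. This is a minimal illustration, not the model's shipped code: the class name `MeanPoolClassifier` is hypothetical, the random tensor stands in for the backbone's hidden states, and it assumes padding tokens are excluded from the mean via the attention mask.

```python
import torch
import torch.nn as nn


class MeanPoolClassifier(nn.Module):
    """Hypothetical head: masked mean pooling followed by a linear layer."""

    def __init__(self, hidden_size: int = 1024, num_strategies: int = 4):
        super().__init__()
        self.classifier = nn.Linear(hidden_size, num_strategies)

    def forward(self, hidden_states: torch.Tensor, attention_mask: torch.Tensor) -> torch.Tensor:
        # hidden_states: (batch, seq_len, hidden); attention_mask: (batch, seq_len)
        mask = attention_mask.unsqueeze(-1).to(hidden_states.dtype)
        summed = (hidden_states * mask).sum(dim=1)   # sum over real tokens only
        counts = mask.sum(dim=1).clamp(min=1.0)      # avoid division by zero
        pooled = summed / counts                     # (batch, hidden)
        return self.classifier(pooled)               # (batch, num_strategies) logits


head = MeanPoolClassifier()
hidden = torch.randn(2, 5, 1024)  # stand-in for backbone outputs
mask = torch.tensor([[1, 1, 1, 0, 0], [1, 1, 1, 1, 1]])
logits = head(hidden, mask)
print(logits.shape)
```

Masking before averaging keeps the pooled vector independent of how much padding a batch happens to contain.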
86
+
87
+ ## Training Details
88
+
89
+ ### Dataset
90
+
91
+ The model was trained on 80,000+ queries from 11 benchmarks:
92
+
93
+ | Domain | Datasets |
94
+ |--------|----------|
95
+ | Financial | FinReport, FinSlides, FinQA, ConvFinQA |
96
+ | Scientific | ArxivQA, SciQAG |
97
+ | General | Wiki-SS, MP-DocVQA, DUDE, VQAnBD, TAT-DQA |
98
+
99
+ ### Training Procedure
100
+
101
+ 1. **Oracle Label Generation**: Run all retrieval pipelines on training queries to collect nDCG@5 and latency metrics
102
+ 2. **Reward Computation**: `r(q, i) = (1 - λ) · nDCG(q, i) + λ · (1 - NormalizedLatency(q, i))`
103
+ 3. **Soft Label Training**: Train with weighted KL divergence loss using reward scores as soft labels
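
The three steps above can be sketched in plain Python. Several details here are assumptions rather than facts from this README: the rewards are turned into a soft-label distribution with a temperature softmax, the per-sample weighting of the loss is omitted because its form is not specified, and the function names and input numbers are purely illustrative.

```python
import math


def reward(ndcg: float, norm_latency: float, lam: float) -> float:
    """r(q, i) = (1 - lam) * nDCG(q, i) + lam * (1 - NormalizedLatency(q, i))."""
    return (1.0 - lam) * ndcg + lam * (1.0 - norm_latency)


def softmax(scores, temperature=1.0):
    """Numerically stable softmax; the temperature is an assumed knob."""
    m = max(scores)
    exps = [math.exp((s - m) / temperature) for s in scores]
    total = sum(exps)
    return [e / total for e in exps]


def kl_divergence(target, predicted, eps=1e-12):
    """KL(target || predicted) for two discrete distributions."""
    return sum(t * math.log((t + eps) / (p + eps)) for t, p in zip(target, predicted))


# Made-up per-strategy measurements for a single query (not real results).
ndcg = [0.80, 0.70, 0.60, 0.55]
norm_latency = [1.0, 0.3, 0.8, 0.1]

rewards = [reward(n, l, lam=0.0) for n, l in zip(ndcg, norm_latency)]
soft_labels = softmax(rewards, temperature=0.1)   # assumed reward-to-label mapping
predicted = softmax([2.0, 0.5, 0.1, 0.0])         # stand-in for router logits
loss = kl_divergence(soft_labels, predicted)
print(rewards, soft_labels, loss)
```

With λ = 0 the reward reduces to nDCG alone, which matches the accuracy-focused setting in the hyperparameter table.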
104
+
105
+ ### Hyperparameters
106
+
107
+ | Parameter | Value |
108
+ |-----------|-------|
109
+ | Learning Rate | 1e-4 |
110
+ | Batch Size | 16 |
111
+ | Epochs | 2 |
112
+ | Weight Decay | 0.01 |
113
+ | Warmup Ratio | 0.1 |
114
+ | Scheduler | Cosine |
115
+ | Precision | bfloat16 |
116
+ | λ (trade-off) | 0.0 (accuracy-focused) |
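
To make the learning rate, warmup ratio, and scheduler rows concrete, here is a small sketch of a cosine schedule with warmup. It assumes linear warmup to the peak rate followed by cosine decay to zero, and a step count of 10,000 derived from exactly 80,000 queries, batch size 16, and 2 epochs with no gradient accumulation. Neither the schedule's exact shape nor the step count is stated in this README, and `learning_rate` is a hypothetical helper.

```python
import math


def learning_rate(step: int, total_steps: int, peak_lr: float = 1e-4,
                  warmup_ratio: float = 0.1) -> float:
    """Linear warmup to peak_lr, then cosine decay to zero (assumed shape)."""
    warmup_steps = int(total_steps * warmup_ratio)
    if step < warmup_steps:
        return peak_lr * step / max(1, warmup_steps)
    progress = (step - warmup_steps) / max(1, total_steps - warmup_steps)
    return peak_lr * 0.5 * (1.0 + math.cos(math.pi * progress))


# Illustrative step count: 80,000 queries / batch size 16 * 2 epochs.
total_steps = 80_000 // 16 * 2
for step in (0, 500, 1_000, 5_500, 10_000):
    print(step, learning_rate(step, total_steps))
```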
117
+
118
+ ## Performance
119
+
120
+ ### Latency
121
+
122
+ | Component | Time |
123
+ |-----------|------|
124
+ | Router Inference | ~15ms |
125
+
126
+ ## Intended Use
127
+
128
+ IRouterLM is designed for:
129
+
130
+ - **RAG Systems**: Automatically select the optimal retrieval strategy per query
131
+ - **Document QA**: Route queries to text-only or multimodal pipelines based on query semantics
132
+ - **Cost Optimization**: Reduce computational costs by avoiding expensive pipelines when simpler ones suffice
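
One way to wire the router's output into a RAG system is a small dispatcher that maps the predicted strategy to a retrieval callable. The sketch below is hypothetical throughout: the `route` function, the confidence threshold, the fallback choice, and the stub pipelines are illustrative and not part of this repository.

```python
from typing import Callable, Dict, List, Sequence, Tuple

STRATEGIES = ["MULTIMODAL_RERANK", "MULTIMODAL-SINGLE", "TEXT_RERANK", "TEXT-SINGLE"]

Pipeline = Callable[[str], List[str]]


def route(query: str, probabilities: Sequence[float], pipelines: Dict[str, Pipeline],
          fallback: str = "MULTIMODAL_RERANK",
          min_confidence: float = 0.5) -> Tuple[str, List[str]]:
    """Run the pipeline for the most probable strategy, or the fallback if unsure."""
    best = max(range(len(probabilities)), key=lambda i: probabilities[i])
    name = STRATEGIES[best]
    # Fall back when the router is not confident or the pipeline is not deployed.
    if probabilities[best] < min_confidence or name not in pipelines:
        name = fallback
    return name, pipelines[name](query)


# Stub pipelines standing in for real retrievers.
pipelines: Dict[str, Pipeline] = {
    name: (lambda q, n=name: [f"{n} result for: {q}"]) for name in STRATEGIES
}

chosen, docs = route("What was the revenue growth in Q3 2024?",
                     [0.05, 0.10, 0.15, 0.70], pipelines)
print(chosen, docs)
```

Falling back to the most thorough pipeline when confidence is low trades latency for accuracy. A latency-sensitive deployment might choose a cheaper fallback instead.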
133
+
134
+ ### Limitations
135
+
136
+ - Trained on English queries only
137
+ - Optimized for document retrieval tasks (financial, scientific, general domains)
138
+ - Requires the corresponding retrieval pipelines to be available
139
+
140
+ ## License
141
+
142
+ MIT License
143
+
144
+ ## Acknowledgments
145
+
146
+ This work builds on:
147
+ - [Qwen3](https://huggingface.co/Qwen/Qwen3-0.6B-Base) for the base model
148
+ - [ColPali](https://github.com/illuin-tech/colpali) for multimodal late-interaction retrieval
149
+ - [PEFT](https://github.com/huggingface/peft) for efficient fine-tuning