anonymous committed on
Commit
19a4f4d
·
verified ·
1 Parent(s): 4aa8134

Upload IRouterLM model

Browse files
Files changed (3) hide show
  1. config.json +4 -0
  2. configuration_irouterlm.py +32 -0
  3. modeling_irouterlm.py +114 -0
config.json CHANGED
@@ -2,6 +2,10 @@
2
  "architectures": [
3
  "IRouterLMModel"
4
  ],
 
 
 
 
5
  "base_model_name": "Qwen/Qwen3-0.6B-Base",
6
  "classifier_dropout": 0.1,
7
  "dtype": "float32",
 
2
  "architectures": [
3
  "IRouterLMModel"
4
  ],
5
+ "auto_map": {
6
+ "AutoConfig": "configuration_irouterlm.IRouterLMConfig",
7
+ "AutoModel": "modeling_irouterlm.IRouterLMModel"
8
+ },
9
  "base_model_name": "Qwen/Qwen3-0.6B-Base",
10
  "classifier_dropout": 0.1,
11
  "dtype": "float32",
configuration_irouterlm.py ADDED
@@ -0,0 +1,32 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
"""IRouterLM Configuration - RAG Strategy Router Model Configuration."""

from transformers import PretrainedConfig


# Default ordered label names: index i is the class id emitted by the
# classifier head in modeling_irouterlm.IRouterLMModel.
# NOTE(review): the mixed '_' / '-' separators look inconsistent but are
# presumably the exact labels used at training time — do not "normalize"
# them without confirming against the training data.
STRATEGY_NAMES = [
    "MULTIMODAL_RERANK",
    "MULTIMODAL-SINGLE",
    "TEXT_RERANK",
    "TEXT-SINGLE",
]
12
+
13
+
14
class IRouterLMConfig(PretrainedConfig):
    """Configuration for IRouterLM - a RAG strategy router model.

    Args:
        base_model_name: Hugging Face model id of the Qwen3 backbone.
        hidden_size: Width of the backbone's final hidden state; sizes the
            classifier head, so it must match the chosen backbone.
        num_labels: Number of routing strategies (classifier outputs);
            forwarded to ``PretrainedConfig``.
        classifier_dropout: Dropout probability applied before the head.
        strategy_names: Ordered class-id -> human-readable name mapping.
            Defaults to ``STRATEGY_NAMES``. A private copy is stored so
            that mutating the caller's list (or the module-level default)
            afterwards cannot leak into this config.
        **kwargs: Forwarded unchanged to ``PretrainedConfig``.
    """

    model_type = "irouterlm"

    def __init__(
        self,
        base_model_name: str = "Qwen/Qwen3-0.6B-Base",
        hidden_size: int = 1024,
        num_labels: int = 4,
        classifier_dropout: float = 0.1,
        # Fixed annotation: the default is None, so the type is optional.
        strategy_names: "list[str] | None" = None,
        **kwargs,
    ):
        super().__init__(num_labels=num_labels, **kwargs)
        self.base_model_name = base_model_name
        self.hidden_size = hidden_size
        self.classifier_dropout = classifier_dropout
        # Defensive copy: `strategy_names or STRATEGY_NAMES` aliased the
        # module-level default list across every instance, so mutating one
        # config's strategy_names would silently mutate all of them.
        self.strategy_names = list(strategy_names) if strategy_names else list(STRATEGY_NAMES)
modeling_irouterlm.py ADDED
@@ -0,0 +1,114 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
"""IRouterLM Model - RAG Strategy Router Model."""

import torch
import torch.nn as nn
from transformers import PreTrainedModel, Qwen3Model

# Relative import: this module is expected to load as part of the repo
# package (e.g. via the auto_map entries in config.json / trust_remote_code).
from .configuration_irouterlm import IRouterLMConfig
10
class IRouterLMModel(PreTrainedModel):
    """
    IRouterLM: Intelligent Router for RAG Strategy Selection.

    A Qwen3-0.6B based model fine-tuned for classifying queries
    into optimal RAG retrieval strategies.

    Strategies:
        0: MULTIMODAL_RERANK - Multimodal retrieval with reranking
        1: MULTIMODAL-SINGLE - Single-stage multimodal retrieval
        2: TEXT_RERANK - Text-only retrieval with reranking
        3: TEXT-SINGLE - Single-stage text retrieval
    """

    config_class = IRouterLMConfig
    _no_split_modules = ["Qwen3DecoderLayer"]

    def __init__(self, config: IRouterLMConfig):
        super().__init__(config)

        # Load the base Qwen3 backbone. NOTE(review): downloading
        # pretrained weights inside __init__ means every construction of
        # this model (including from_config) fetches the backbone from the
        # hub; the checkpoint's own weights then overwrite these anyway.
        self.transformer = Qwen3Model.from_pretrained(
            config.base_model_name,
            trust_remote_code=True,
        )

        # Classification head applied to a mean-pooled sentence embedding.
        self.dropout = nn.Dropout(config.classifier_dropout)
        self.classifier = nn.Linear(config.hidden_size, config.num_labels)

        # Runs _init_weights over submodules (HF convention).
        self.post_init()

    def _init_weights(self, module):
        """Initialize linear layers: normal(std=0.02) weight, zero bias."""
        if isinstance(module, nn.Linear):
            nn.init.normal_(module.weight, std=0.02)
            if module.bias is not None:
                nn.init.zeros_(module.bias)

    def forward(
        self,
        input_ids: torch.Tensor,
        attention_mask: torch.Tensor = None,
        labels: torch.Tensor = None,
        output_hidden_states: bool = None,
        return_dict: bool = True,
        **kwargs,
    ):
        """
        Forward pass for strategy classification.

        Args:
            input_ids: Token ids of shape (batch, seq_len).
            attention_mask: 1 for real tokens, 0 for padding; drives the
                masked mean pooling. If None, every position is pooled.
            labels: Optional non-negative per-strategy rewards of shape
                (batch, num_labels); enables the weighted soft-label loss.
            output_hidden_states: If truthy, the backbone's per-layer
                hidden states are included in the returned dict. (This
                argument was previously accepted but silently ignored,
                while the backbone was always forced to materialize all
                layer activations.)
            return_dict: Accepted for HF API compatibility; this model
                always returns a plain dict.

        Returns:
            dict with "loss" (None when no labels), "logits" of shape
            (batch, num_labels), and "hidden_states" when requested.
        """
        outputs = self.transformer(
            input_ids=input_ids,
            attention_mask=attention_mask,
            # Only materialize all layer activations when the caller asks;
            # the original always requested them and discarded the result.
            output_hidden_states=bool(output_hidden_states),
        )

        hidden_states = outputs.last_hidden_state

        # Masked mean pooling over the sequence dimension.
        if attention_mask is not None:
            # Cast the mask to the activations' dtype instead of .float()
            # so fp16/bf16 inference does not silently upcast the pooled
            # embedding to fp32.
            mask = attention_mask.unsqueeze(-1).to(hidden_states.dtype)
            # Token counts are integers >= 0, so clamp(min=1.0) guards the
            # all-padding row exactly like the original 1e-9 epsilon did
            # (0/1 == 0), without an epsilon that underflows to 0 in fp16.
            pooled = (hidden_states * mask).sum(dim=1) / mask.sum(dim=1).clamp(min=1.0)
        else:
            pooled = hidden_states.mean(dim=1)

        logits = self.classifier(self.dropout(pooled))

        loss = None
        if labels is not None:
            loss = self._compute_loss(logits, labels)

        result = {"loss": loss, "logits": logits}
        if output_hidden_states:
            result["hidden_states"] = outputs.hidden_states
        return result

    def _compute_loss(self, logits: torch.Tensor, labels: torch.Tensor) -> torch.Tensor:
        """Weighted soft-label cross-entropy.

        Each row of `labels` (non-negative rewards) is normalized into a
        distribution and scored against log-softmax logits (cross-entropy,
        i.e. KL divergence up to the labels' constant entropy). Per-sample
        losses are weighted by the row's maximum reward so high-confidence
        samples dominate the mean.
        """
        EPS = 1e-8
        target = labels / (labels.sum(dim=-1, keepdim=True) + EPS)
        log_probs = torch.nn.functional.log_softmax(logits, dim=-1)
        per_sample = -(target * log_probs).sum(dim=-1)
        weights = labels.max(dim=-1).values
        return (per_sample * weights).mean()

    def predict(self, input_ids: torch.Tensor, attention_mask: torch.Tensor = None):
        """
        Predict the best RAG strategy for given queries.

        Returns:
            dict with hard "predictions" (class ids), softmax
            "probabilities", and human-readable "strategy_names".
        """
        # Fix: the original called self.eval() and left the model
        # permanently in eval mode; restore the caller's training state.
        was_training = self.training
        self.eval()
        try:
            with torch.no_grad():
                outputs = self.forward(input_ids, attention_mask)
                probs = torch.softmax(outputs["logits"], dim=-1)
                predictions = probs.argmax(dim=-1)
        finally:
            if was_training:
                self.train()

        return {
            "predictions": predictions,
            "probabilities": probs,
            "strategy_names": [self.config.strategy_names[p.item()] for p in predictions],
        }