"""IRouterLM Model - RAG Strategy Router Model."""

import torch
import torch.nn as nn
from transformers import PreTrainedModel, Qwen3Model

from .configuration_irouterlm import IRouterLMConfig


class IRouterLMModel(PreTrainedModel):
    """
    IRouterLM: Intelligent Router for RAG Strategy Selection.

    A Qwen3-0.6B-based model fine-tuned to classify queries
    into the optimal RAG retrieval strategy.

    Strategies:
        0: MULTIMODAL_RERANK - Multimodal retrieval with reranking
        1: MULTIMODAL_SINGLE - Single-stage multimodal retrieval
        2: TEXT_RERANK - Text-only retrieval with reranking
        3: TEXT_SINGLE - Single-stage text retrieval
    """

    config_class = IRouterLMConfig
    _no_split_modules = ["Qwen3DecoderLayer"]

    def __init__(self, config: IRouterLMConfig):
        super().__init__(config)

        # Load the base Qwen3 backbone. Note that when the full IRouterLM
        # checkpoint is later restored via from_pretrained, the weights
        # loaded here are overwritten by the fine-tuned state dict.
        self.transformer = Qwen3Model.from_pretrained(
            config.base_model_name,
            trust_remote_code=True,
        )

        # Classification head
        self.dropout = nn.Dropout(config.classifier_dropout)
        self.classifier = nn.Linear(config.hidden_size, config.num_labels)

        # Initialize the newly added classifier weights (via _init_weights)
        self.post_init()

    def _init_weights(self, module):
        """Initialize classifier weights."""
        if isinstance(module, nn.Linear):
            nn.init.normal_(module.weight, std=0.02)
            if module.bias is not None:
                nn.init.zeros_(module.bias)

    def forward(
        self,
        input_ids: torch.Tensor,
        attention_mask: torch.Tensor = None,
        labels: torch.Tensor = None,
        output_hidden_states: bool = None,
        return_dict: bool = True,
        **kwargs,
    ):
        """
        Forward pass for strategy classification.
        """
        # Get base model outputs
        outputs = self.transformer(
            input_ids=input_ids,
            attention_mask=attention_mask,
            output_hidden_states=True,
        )

        # Masked mean pooling over the sequence dimension:
        # hidden_states (batch, seq_len, hidden) -> pooled (batch, hidden)
        hidden_states = outputs.last_hidden_state

        if attention_mask is not None:
            mask_expanded = attention_mask.unsqueeze(-1).expand(hidden_states.size()).float()
            sum_hidden = torch.sum(hidden_states * mask_expanded, dim=1)
            # clamp avoids division by zero for fully padded rows
            sum_mask = torch.clamp(mask_expanded.sum(dim=1), min=1e-9)
            pooled = sum_hidden / sum_mask
        else:
            pooled = hidden_states.mean(dim=1)

        # Classification
        pooled = self.dropout(pooled)
        logits = self.classifier(pooled)

        loss = None
        if labels is not None:
            loss = self._compute_loss(logits, labels)

        return {"loss": loss, "logits": logits}

    def _compute_loss(self, logits: torch.Tensor, labels: torch.Tensor) -> torch.Tensor:
        """
        Compute reward-weighted soft-label cross-entropy.

        Equivalent, up to a constant, to the KL divergence between the
        normalized reward distribution and the predicted distribution.
        """
        EPS = 1e-8
        # Normalize per-strategy rewards into a target distribution
        reward_sum = labels.sum(dim=-1, keepdim=True)
        labels_normalized = labels / (reward_sum + EPS)
        log_probs = torch.nn.functional.log_softmax(logits, dim=-1)
        # Soft-label cross-entropy per sample
        sample_losses = -(labels_normalized * log_probs).sum(dim=-1)
        # Weight each sample by its best achievable reward, so queries with
        # a clearly winning strategy contribute more to the batch loss
        sample_weights = labels.max(dim=-1)[0]
        return (sample_losses * sample_weights).mean()
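
    # Illustrative worked example (added for clarity, not from the original
    # source): for a sample with rewards labels = [0.9, 0.1, 0.0, 0.0], the
    # normalized target is [0.9, 0.1, 0.0, 0.0] and the sample weight is 0.9;
    # for rewards [0.3, 0.3, 0.2, 0.2] the target is much flatter and the
    # weight only 0.3, so ambiguous queries are down-weighted in the batch.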

    def predict(self, input_ids: torch.Tensor, attention_mask: torch.Tensor = None):
        """
        Predict the best RAG strategy for given queries.
        """
        self.eval()
        with torch.no_grad():
            outputs = self.forward(input_ids, attention_mask)
            probs = torch.softmax(outputs["logits"], dim=-1)
            predictions = probs.argmax(dim=-1)

        return {
            "predictions": predictions,
            "probabilities": probs,
            "strategy_names": [self.config.strategy_names[p.item()] for p in predictions],
        }
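

if __name__ == "__main__":
    # Minimal usage sketch (illustrative; assumes IRouterLMConfig's defaults
    # point at a reachable Qwen3 base checkpoint whose tokenizer matches the
    # router's training setup). Because of the relative import above, run it
    # as a module from the package root, e.g. python -m <package>.modeling_irouterlm
    from transformers import AutoTokenizer

    config = IRouterLMConfig()
    model = IRouterLMModel(config)

    tokenizer = AutoTokenizer.from_pretrained(config.base_model_name)
    batch = tokenizer(
        ["What trend does the revenue chart in the annual report show?"],
        return_tensors="pt",
        padding=True,
    )
    result = model.predict(batch["input_ids"], batch["attention_mask"])
    print(result["strategy_names"])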