"""IRouterLM Model - RAG Strategy Router Model."""

import torch
import torch.nn as nn
from transformers import PreTrainedModel, Qwen3Model

from .configuration_irouterlm import IRouterLMConfig


class IRouterLMModel(PreTrainedModel):
    """
    IRouterLM: Intelligent Router for RAG Strategy Selection.

    A Qwen3 backbone (loaded from ``config.base_model_name``; described as
    Qwen3-0.6B) followed by masked mean pooling and a linear classification
    head, fine-tuned to classify queries into RAG retrieval strategies.

    Strategies (the names returned by ``predict`` come from
    ``config.strategy_names``):
        0: MULTIMODAL_RERANK - Multimodal retrieval with reranking
        1: MULTIMODAL-SINGLE - Single-stage multimodal retrieval
        2: TEXT_RERANK - Text-only retrieval with reranking
        3: TEXT-SINGLE - Single-stage text retrieval
    """

    config_class = IRouterLMConfig
    # Tells device_map="auto" never to split a decoder layer across devices.
    _no_split_modules = ["Qwen3DecoderLayer"]

    def __init__(self, config: IRouterLMConfig):
        super().__init__(config)

        # NOTE(review): the backbone's pretrained weights are fetched on every
        # construction, including when this model is itself restored with
        # ``from_pretrained`` (whose checkpoint presumably also carries the
        # ``transformer.*`` weights that then overwrite these). Qwen3Model is a
        # built-in transformers class, so ``trust_remote_code=True`` should not
        # be needed here -- left unchanged; review before pointing
        # ``base_model_name`` at untrusted repositories.
        self.transformer = Qwen3Model.from_pretrained(
            config.base_model_name,
            trust_remote_code=True,
        )

        # Classification head applied to the mean-pooled last hidden state.
        self.dropout = nn.Dropout(config.classifier_dropout)
        self.classifier = nn.Linear(config.hidden_size, config.num_labels)

        # Initialize weights (only the classifier is touched; see _init_weights).
        self.post_init()

    def _init_weights(self, module):
        """Initialize the classification head.

        Only ``self.classifier`` is (re)initialized. ``self.transformer`` already
        holds pretrained weights, and its Linear layers must not be overwritten
        with random values if this hook reaches them.
        """
        if module is getattr(self, "classifier", None):
            nn.init.normal_(module.weight, std=0.02)
            if module.bias is not None:
                nn.init.zeros_(module.bias)

    def forward(
        self,
        input_ids: torch.Tensor,
        attention_mask: torch.Tensor = None,
        labels: torch.Tensor = None,
        output_hidden_states: bool = None,
        return_dict: bool = True,
        **kwargs,
    ):
        """
        Forward pass for strategy classification.

        Args:
            input_ids: Token ids fed to the Qwen3 backbone.
            attention_mask: Optional mask. Positions with non-zero mask values
                are averaged during pooling; if ``None``, every position is
                averaged.
            labels: Optional targets, either per-class rewards / soft labels
                shaped like the logits, or integer class indices with one fewer
                dimension (see ``_compute_loss``).
            output_hidden_states: If truthy, the backbone's per-layer hidden
                states are requested and returned under ``"hidden_states"``.
            return_dict: Accepted for API compatibility; a dict is always
                returned.
            **kwargs: Ignored (absorbs extra keys a training loop may pass).

        Returns:
            dict with ``"logits"`` (one score per label) and ``"loss"`` (``None``
            when ``labels`` is not given), plus ``"hidden_states"`` when
            ``output_hidden_states`` is truthy.
        """
        outputs = self.transformer(
            input_ids=input_ids,
            attention_mask=attention_mask,
            # Pooling only needs last_hidden_state; collect all layers only on request.
            output_hidden_states=output_hidden_states,
        )
        hidden_states = outputs.last_hidden_state

        # Masked mean pooling over the sequence dimension, accumulated in float32
        # for accuracy regardless of the backbone's dtype.
        hidden_fp32 = hidden_states.float()
        if attention_mask is not None:
            mask = attention_mask.unsqueeze(-1).to(hidden_fp32.dtype)
            summed = (hidden_fp32 * mask).sum(dim=1)
            # Clamp so an all-padding row yields zeros instead of dividing by zero.
            counts = mask.sum(dim=1).clamp(min=1e-9)
            pooled = summed / counts
        else:
            pooled = hidden_fp32.mean(dim=1)
        # Match the head's dtype: a float32 input against bf16/fp16 Linear
        # weights raises a dtype-mismatch error.
        pooled = pooled.to(self.classifier.weight.dtype)

        pooled = self.dropout(pooled)
        logits = self.classifier(pooled)

        loss = self._compute_loss(logits, labels) if labels is not None else None

        result = {"loss": loss, "logits": logits}
        if output_hidden_states:
            result["hidden_states"] = outputs.hidden_states
        return result

    def _compute_loss(self, logits: torch.Tensor, labels: torch.Tensor) -> torch.Tensor:
        """Reward-weighted soft-label cross-entropy.

        Each row of ``labels`` is normalized into a target distribution ``p``.
        The per-sample loss is ``-sum(p * log_softmax(logits))``, which equals
        KL(p || softmax(logits)) plus the entropy of ``p``. The entropy term does
        not depend on the logits, so the gradient matches the KL divergence.
        Each sample is weighted by its largest label value, and the weighted
        losses are averaged over the batch.

        Integer class indices (shape ``logits.shape[:-1]``) are also accepted.
        They are treated as one-hot targets, which reduces to plain
        cross-entropy.

        NOTE(review): assumes label values are non-negative rewards -- confirm
        with the data pipeline. A row of all zeros gets weight 0 and contributes
        0, but still counts toward the mean's denominator.
        """
        eps = 1e-8
        # float32 keeps log_softmax numerically stable under bf16/fp16 logits.
        logits = logits.float()
        if labels.dim() == logits.dim() - 1:
            labels = torch.nn.functional.one_hot(labels.long(), num_classes=logits.size(-1))
        labels = labels.to(logits.dtype)

        targets = labels / (labels.sum(dim=-1, keepdim=True) + eps)
        log_probs = torch.nn.functional.log_softmax(logits, dim=-1)
        sample_losses = -(targets * log_probs).sum(dim=-1)
        sample_weights = labels.max(dim=-1).values
        return (sample_losses * sample_weights).mean()

    def predict(self, input_ids: torch.Tensor, attention_mask: torch.Tensor = None):
        """
        Predict the best RAG strategy for given queries.

        Runs without gradients in eval mode (dropout disabled) and restores the
        module's previous train/eval mode afterwards.

        Returns:
            dict with ``"predictions"`` (argmax index per query),
            ``"probabilities"`` (float32 softmax over the logits), and
            ``"strategy_names"`` (the ``config.strategy_names`` entry for each
            prediction).
        """
        was_training = self.training
        self.eval()
        try:
            with torch.no_grad():
                # Call the module rather than .forward so registered hooks
                # (e.g. device-dispatch hooks) still run.
                outputs = self(input_ids=input_ids, attention_mask=attention_mask)
                probs = torch.softmax(outputs["logits"].float(), dim=-1)
                predictions = probs.argmax(dim=-1)
        finally:
            self.train(was_training)

        names = self.config.strategy_names
        return {
            "predictions": predictions,
            "probabilities": probs,
            # One host transfer instead of a .item() sync per element.
            "strategy_names": [names[idx] for idx in predictions.tolist()],
        }