import math
from typing import Any, Dict, Optional

import torch
import torch.nn as nn
from transformers import Qwen3ForCausalLM


class QueritModel(Qwen3ForCausalLM):
    """Querit reranker based on Qwen3-Embedding-4B backbone with binary classification head.

    The hidden state of the last attended token of each sequence is fed to a
    2-way linear head. The ranking score is ``p[1] - p[0]`` of the softmax over
    the two head outputs, so it lies in ``[-1, 1]``. Training uses a pairwise
    hinge loss computed between items that share the same query id.
    """

    def __init__(self, config, use_lm_head: bool = False):
        """Build the backbone and attach the 2-way scoring head.

        Args:
            config: Model configuration forwarded to the parent constructor;
                ``config.hidden_size`` sets the head's input width.
            use_lm_head: When False (default) the inherited ``lm_head`` is
                replaced by ``None``.
        """
        super().__init__(config)
        hidden_size = self.config.hidden_size
        self.head = nn.Linear(hidden_size, 2)
        # Very small initial weights and zero bias: both head outputs start
        # almost equal, so the initial score p[1] - p[0] is close to 0.
        nn.init.normal_(self.head.weight, std=1e-4)
        nn.init.zeros_(self.head.bias)
        if not use_lm_head:
            self.lm_head = None

    @staticmethod
    def _last_token_pool(
        hidden_states: torch.Tensor,
        attention_mask: Optional[torch.Tensor],
    ) -> torch.Tensor:
        """Return the hidden state of the last attended token of every row.

        For left-padded batches (mask ends in 1) this is exactly
        ``hidden_states[:, -1, :]``. For right-padded batches it selects the
        last real token instead of a padding position. Rows whose mask is all
        zeros, and calls without a 2-D mask, fall back to the final position.
        """
        if attention_mask is None or attention_mask.dim() != 2:
            return hidden_states[:, -1, :]

        seq_len = attention_mask.shape[1]
        # argmax returns the first maximal index; on the reversed 0/1 mask that
        # is the distance of the last attended token from the end of the row.
        # The `!= 0` keeps float masks with fractional values from truncating.
        offset_from_end = (attention_mask != 0).flip(dims=[1]).long().argmax(dim=1)
        last_index = (seq_len - 1 - offset_from_end).to(hidden_states.device)
        row_index = torch.arange(hidden_states.shape[0], device=hidden_states.device)
        return hidden_states[row_index, last_index]

    def forward(
        self,
        input_ids: torch.Tensor,
        attention_mask: torch.Tensor,
        labels: Optional[torch.Tensor] = None,
        scores: Optional[torch.Tensor] = None,
        qids: Optional[torch.Tensor] = None,
    ) -> Dict[str, Any]:
        """Score a batch and, when targets are supplied, compute the loss.

        Args:
            input_ids: Token ids passed to the backbone.
            attention_mask: Attention mask passed to the backbone; also used to
                locate the last attended token of each sequence.
            labels: Only gates loss computation together with ``scores``; its
                values are not read here.
            scores: Per-item relevance targets used as the labels of the
                pairwise hinge loss.
            qids: Per-item query ids; only items sharing a qid are paired.
                Required whenever the loss is computed.

        Returns:
            Dict with ``"loss"`` (scalar tensor or ``None``), ``"qids"`` (as
            passed in), ``"score"`` (shape ``(batch, 1)``) and ``"pred_label"``
            (argmax over the two head outputs, shape ``(batch,)``).

        Raises:
            ValueError: If ``labels`` and ``scores`` are given without ``qids``.
        """
        outputs = self.model(
            input_ids=input_ids,
            attention_mask=attention_mask,
        )
        cls_hidden = self._last_token_pool(outputs.last_hidden_state, attention_mask)
        logits = self.head(cls_hidden)
        probs = torch.softmax(logits, dim=-1)
        pred_labels = torch.argmax(probs, dim=-1)
        rank_scores = self._compute_score(probs)

        loss = None
        if labels is not None and scores is not None:
            if qids is None:
                # Fail early with a clear message instead of an AttributeError
                # on `None.unsqueeze` deep inside the loss.
                raise ValueError(
                    "`qids` must be provided together with `labels` and "
                    "`scores` to compute the pairwise hinge loss."
                )
            loss = self._pairwise_hinge_loss(rank_scores, scores, qids)

        return {
            "loss": loss,
            "qids": qids,
            "score": rank_scores,
            "pred_label": pred_labels,
        }

    def _pairwise_hinge_loss(
        self,
        logits: torch.Tensor,
        labels: torch.Tensor,
        qids: torch.Tensor,
        margin_weight: float = 0.8,
        topk: bool = False,
        pairdiff_mask: Optional[torch.Tensor] = None,
    ) -> torch.Tensor:
        """Mean pairwise hinge loss over ordered pairs within the same query.

        A pair ``(i, j)`` is active when its margin
        ``(labels[i] - labels[j] [+ pairdiff_mask[i, j]]) * mask[i, j] * margin_weight``
        exceeds ``1e-6``. Its loss is
        ``relu(margin + clamp(-(score[i] - score[j]), min=-10))``.

        Args:
            logits: Per-item ranking scores, ``(batch, 1)`` or ``(batch,)``.
            labels: Per-item relevance targets, ``(batch,)`` or ``(batch, 1)``.
            qids: Per-item query ids, ``(batch,)`` or ``(batch, 1)``.
            margin_weight: Scale applied to the label difference.
            topk: When True, the pair mask is multiplied by the column weights
                returned by ``_get_topk_mask``.
            pairdiff_mask: Optional tensor added to the label difference; it
                must broadcast against ``(batch, batch)``.

        Returns:
            Scalar tensor: summed loss divided by the number of active pairs
            (plus ``1e-5`` to avoid division by zero).
        """
        # Normalise shapes up front. A (batch, 1) `qids` would otherwise
        # broadcast silently to (batch, batch, batch) below, and a (batch, 1)
        # `labels` would make `expand` fail.
        # NOTE(review): `_get_topk_mask` treats labels < 0 as invalid, but
        # such items are not excluded from pairing here - confirm intent.
        qids = qids.reshape(-1)
        logits = logits.reshape(-1, 1)
        labels = labels.reshape(-1, 1)
        batch_size = logits.shape[0]

        qid_mask = (qids.unsqueeze(0) == qids.unsqueeze(1)).float()
        if topk:
            qid_mask = qid_mask * self._get_topk_mask(
                qids, logits.squeeze(-1), labels.squeeze(-1)
            )

        # Row index i plays the "positive" role, column index j the "negative".
        score_pos = logits.expand(-1, batch_size)
        score_neg = score_pos.transpose(0, 1)
        pos = labels.expand(-1, batch_size)
        neg = pos.transpose(0, 1)

        label_diff = pos - neg
        if pairdiff_mask is not None:
            label_diff = label_diff + pairdiff_mask
        margin = label_diff * qid_mask * margin_weight

        # Only pairs with a strictly positive margin contribute.
        pair_mask = (margin > 1e-6).float()
        score_diff = score_pos - score_neg
        margin_diff = margin + torch.clamp(-score_diff, min=-10.0)
        loss = torch.relu(margin_diff) * pair_mask
        return torch.sum(loss) / (torch.sum(pair_mask) + 1e-5)

    def _get_topk_mask(
        self,
        qids: torch.Tensor,
        logits: torch.Tensor,
        labels: torch.Tensor,
        topk_ratio: float = 0.3,
        topk_weight: float = 2.0,
    ) -> torch.Tensor:
        """Build a ``(batch, batch)`` column-weight matrix favouring top-scored items.

        For every query, ``k = ceil(topk_ratio * n_valid)`` where ``n_valid`` is
        the number of its items with ``label >= 0``. The ``k`` items with the
        highest logits within that query receive ``topk_weight``; all other
        items keep weight ``1.0``.

        Args:
            qids: Per-item query ids (flattened internally).
            logits: Per-item scores (flattened internally).
            labels: Per-item labels (flattened internally).
            topk_ratio: Fraction of valid items per query to up-weight.
            topk_weight: Weight assigned to the selected items.

        Returns:
            Float32 tensor ``M`` of shape ``(batch, batch)`` with
            ``M[i, j] == weight[j]``, i.e. the weight depends on the column.
        """
        flatten_qids = qids.reshape(-1)
        flatten_logits = logits.reshape(-1)
        flatten_labels = labels.reshape(-1)
        batch_size = flatten_qids.shape[0]

        position_mask = torch.ones(
            batch_size, dtype=torch.float32, device=logits.device
        )
        for uq in torch.unique(flatten_qids):
            indices = (flatten_qids == uq).nonzero(as_tuple=True)[0]
            valid_count = (flatten_labels[indices] >= 0).sum().item()
            k = math.ceil(valid_count * topk_ratio)
            if k == 0:
                continue
            # NOTE(review): ranking uses every item of the query, including
            # those with label < 0; only `k` is derived from the valid count.
            cur_logits = flatten_logits[indices]
            topk_idx = indices[cur_logits.argsort(descending=True)[:k]]
            position_mask[topk_idx] = topk_weight

        # Broadcast the per-item weights along rows so that M[i, j] = weight[j].
        return position_mask.unsqueeze(0).expand(batch_size, batch_size)

    def _compute_score(self, probs: torch.Tensor) -> torch.Tensor:
        """Map 2-way probabilities to a score ``p[1] - p[0]``.

        Args:
            probs: Tensor whose last dimension holds the two class
                probabilities.

        Returns:
            Tensor with the last dimension reduced to size 1 (``keepdim``).
        """
        weights = torch.tensor([-1.0, 1.0], device=probs.device)
        return (probs * weights).sum(dim=-1, keepdim=True)