# Hugging Face Hub: Querit-4B / modeling_querit_4b.py
# Tags: Transformers, Safetensors, qwen3, text-generation-inference
# Uploaded by moshesbeta ("Upload 8 files"), commit 5fade51 (verified), 3.77 kB
import math
from typing import Any, Dict, Optional
import torch
import torch.nn as nn
from transformers import Qwen3ForCausalLM
class QueritModel(Qwen3ForCausalLM):
    """Querit reranker based on Qwen3-Embedding-4B backbone with binary classification head.

    The decoder backbone (``self.model``) encodes each sequence. The hidden state
    of the last non-padding token is fed to a 2-way linear head (``self.head``).
    The ranking score is ``P(class 1) - P(class 0)``, which lies in ``[-1, 1]``.

    Args:
        config: Model configuration forwarded to ``Qwen3ForCausalLM``.
        use_lm_head: When ``False`` (the default), the inherited language-model
            head is dropped (``self.lm_head = None``). ``forward`` never uses it;
            it only uses ``self.model`` and ``self.head``.
    """

    def __init__(self, config, use_lm_head: bool = False):
        super().__init__(config)
        hidden_size = self.config.hidden_size
        # Tiny weights and zero bias keep the initial logits close to zero.
        self.head = nn.Linear(hidden_size, 2)
        nn.init.normal_(self.head.weight, std=1e-4)
        nn.init.zeros_(self.head.bias)
        if not use_lm_head:
            self.lm_head = None

    def forward(
        self,
        input_ids: torch.Tensor,
        attention_mask: torch.Tensor,
        labels: Optional[torch.Tensor] = None,
        scores: Optional[torch.Tensor] = None,
        qids: Optional[torch.Tensor] = None,
    ) -> Dict[str, Any]:
        """Score every sequence in the batch and optionally compute the ranking loss.

        Args:
            input_ids: Token ids, one row per sequence.
            attention_mask: Padding mask for ``input_ids``; non-zero entries mark
                real tokens. Left- and right-padded batches are both supported.
            labels: Only acts as a switch. The loss is computed when both
                ``labels`` and ``scores`` are given.
                NOTE(review): ``labels`` itself does not enter the loss.
            scores: Graded relevance targets, one per sequence (``(batch,)`` or
                ``(batch, 1)``). These are the targets of the pairwise loss.
            qids: Query id per sequence. Pairs are only compared within the same
                query. When ``None``, the whole batch is treated as one query.

        Returns:
            A dict with:
            - ``loss``: scalar tensor, or ``None`` when no targets were given.
            - ``qids``: passed through unchanged.
            - ``score``: ``(batch, 1)`` ranking scores in ``[-1, 1]``.
            - ``pred_label``: ``(batch,)`` argmax class of the head.
        """
        outputs = self.model(
            input_ids=input_ids,
            attention_mask=attention_mask,
        )
        cls_hidden = self._last_token_hidden(outputs.last_hidden_state, attention_mask)
        logits = self.head(cls_hidden)
        probs = torch.softmax(logits, dim=-1)
        pred_labels = torch.argmax(probs, dim=-1)
        rank_scores = self._compute_score(probs)
        loss = None
        if labels is not None and scores is not None:
            loss = self._pairwise_hinge_loss(rank_scores, scores, qids)
        return {
            "loss": loss,
            "qids": qids,
            "score": rank_scores,
            "pred_label": pred_labels,
        }

    @staticmethod
    def _last_token_hidden(
        hidden_states: torch.Tensor,
        attention_mask: Optional[torch.Tensor],
    ) -> torch.Tensor:
        """Return the hidden state of the last non-padding token of each sequence.

        With left padding this is position ``-1``. With right padding, position
        ``-1`` would be a pad token, so the index of the last non-zero mask entry
        is used instead. When no 2-D mask is available, position ``-1`` is used.
        """
        if attention_mask is None or attention_mask.dim() != 2:
            return hidden_states[:, -1, :]
        is_token = attention_mask != 0
        positions = torch.arange(is_token.shape[1], device=is_token.device)
        # Largest position holding a real token (0 for an all-padding row).
        last_idx = (positions * is_token).max(dim=1).values.to(hidden_states.device)
        batch_idx = torch.arange(hidden_states.shape[0], device=hidden_states.device)
        return hidden_states[batch_idx, last_idx]

    def _pairwise_hinge_loss(
        self,
        logits: torch.Tensor,
        labels: torch.Tensor,
        qids: Optional[torch.Tensor],
        margin_weight: float = 0.8,
        topk: bool = False,
        pairdiff_mask: Optional[torch.Tensor] = None,
    ) -> torch.Tensor:
        """Pairwise hinge loss over ordered pairs ``(i, j)`` of the same query.

        Definitions:
        - ``margin_ij = (labels[i] - labels[j] [+ pairdiff_mask[i, j]]) * w_ij * margin_weight``.
        - ``w_ij`` is 1 for same-query pairs and 0 otherwise. With ``topk``, it is
          multiplied by the weight from ``_get_topk_mask``.

        Only pairs with ``margin_ij > 1e-6`` are active. Each active pair contributes
        ``relu(margin_ij - min(s_i - s_j, 10))``. The result is the mean over active
        pairs; ``1e-5`` in the denominator keeps an empty pair set finite.

        Args:
            logits: Model scores, ``(batch,)`` or ``(batch, 1)``.
            labels: Graded relevance targets, ``(batch,)`` or ``(batch, 1)``.
            qids: Query id per item. ``None`` means all items share one query.
            margin_weight: Scale applied to the label gap.
            topk: Enable the top-k pair re-weighting.
            pairdiff_mask: Optional ``(batch, batch)`` term added to the label gap.
        """
        scores_flat = logits.reshape(-1)
        if qids is None:
            # No grouping information: compare every pair in the batch.
            qids = torch.zeros_like(scores_flat, dtype=torch.long)
        qids = qids.reshape(-1)
        qid_mask = (qids.unsqueeze(0) == qids.unsqueeze(1)).float()
        if topk:
            qid_mask = qid_mask * self._get_topk_mask(qids, scores_flat, labels)
        score_col = scores_flat.unsqueeze(1)
        label_col = labels.reshape(-1, 1)
        # [i, j] entries: s_i - s_j and labels[i] - labels[j].
        score_diff = score_col - score_col.transpose(0, 1)
        label_diff = label_col - label_col.transpose(0, 1)
        if pairdiff_mask is not None:
            label_diff = label_diff + pairdiff_mask
        margin = label_diff * qid_mask * margin_weight
        pair_mask = (margin > 1e-6).float()
        # Cap the score gap credited toward the margin at 10.
        margin_diff = margin + torch.clamp(-score_diff, min=-10.0)
        loss = torch.relu(margin_diff) * pair_mask
        return torch.sum(loss) / (torch.sum(pair_mask) + 1e-5)

    def _get_topk_mask(self, qids, logits, labels):
        """Build a ``(batch, batch)`` pair weight based on per-query top-k membership.

        Entry ``[i, j]`` is 2.0 when item ``j`` is among the top
        ``ceil(0.3 * n_valid)`` items of its query, ranked by ``logits``
        (descending). Otherwise the entry is 1.0. ``n_valid`` counts the query's
        items with ``label >= 0``; queries with ``n_valid == 0`` are left at 1.0.
        """
        flatten_qids = qids.reshape(-1)
        flatten_logits = logits.reshape(-1)
        flatten_labels = labels.reshape(-1)
        batch_size = flatten_qids.shape[0]
        position_mask = torch.ones(batch_size, dtype=torch.float32, device=logits.device)
        for uq in torch.unique(flatten_qids):
            indices = (flatten_qids == uq).nonzero(as_tuple=True)[0]
            valid_count = (flatten_labels[indices] >= 0).sum().item()
            k = math.ceil(valid_count * 0.3)
            if k == 0:
                continue
            cur_logits = flatten_logits[indices]
            topk_idx = indices[cur_logits.argsort(descending=True)[:k]]
            position_mask[topk_idx] = 2.0
        # Broadcast along rows so the weight depends only on the column j.
        return position_mask.unsqueeze(0).expand(batch_size, batch_size)

    def _compute_score(self, probs: torch.Tensor) -> torch.Tensor:
        """Map 2-class probabilities to a ranking score ``p1 - p0`` of shape ``(batch, 1)``."""
        weights = torch.tensor([-1.0, 1.0], device=probs.device)
        return (probs * weights).sum(dim=-1, keepdim=True)