Instructions to use Querit/Querit-4B with libraries, inference providers, notebooks, and local apps. Follow these links to get started.
- Libraries
- Transformers
How to use Querit/Querit-4B with Transformers:
```python
# Load model directly
from transformers import AutoTokenizer, AutoModel

tokenizer = AutoTokenizer.from_pretrained("Querit/Querit-4B")
model = AutoModel.from_pretrained("Querit/Querit-4B")
```

- Notebooks
- Google Colab
- Kaggle
| import math | |
| from typing import Any, Dict, Optional | |
| import torch | |
| import torch.nn as nn | |
| from transformers import Qwen3ForCausalLM | |
class QueritModel(Qwen3ForCausalLM):
    """Querit reranker based on Qwen3-Embedding-4B backbone with binary classification head.

    The backbone (``self.model``) encodes each input sequence. The hidden state of
    the last non-padding token is fed to a 2-way linear head. The ranking score
    is ``P(class 1) - P(class 0)``.
    """

    def __init__(self, config, use_lm_head: bool = False):
        """Build the backbone and a freshly initialised 2-way classification head.

        Args:
            config: Configuration forwarded to ``Qwen3ForCausalLM``; its
                ``hidden_size`` sets the input width of the head.
            use_lm_head: When False (default), ``self.lm_head`` is set to None;
                ``forward`` never uses it.
        """
        super().__init__(config)
        hidden_size = self.config.hidden_size
        self.head = nn.Linear(hidden_size, 2)
        # Tiny weights and zero bias: the untrained head outputs near-zero
        # logits, i.e. a near-uniform class distribution.
        nn.init.normal_(self.head.weight, std=1e-4)
        nn.init.zeros_(self.head.bias)
        if not use_lm_head:
            self.lm_head = None

    @staticmethod
    def _last_token_hidden(
        hidden: torch.Tensor, attention_mask: Optional[torch.Tensor]
    ) -> torch.Tensor:
        """Return the hidden state of the last non-padding token of each row.

        Handles both left- and right-padded batches. For left-padded or unpadded
        rows this is the final position, exactly as ``hidden[:, -1, :]``. When no
        2-D mask is available, the final position is used.

        Args:
            hidden: Backbone output, indexed as ``(batch, seq_len, hidden_size)``.
            attention_mask: 2-D mask, non-zero for real tokens.

        Returns:
            Tensor of shape ``(batch, hidden_size)``.
        """
        if attention_mask is None or attention_mask.dim() != 2:
            return hidden[:, -1, :]
        seq_len = attention_mask.shape[-1]
        positions = torch.arange(seq_len, device=attention_mask.device)
        # Largest position holding a real token. A fully masked row yields 0.
        last_idx = ((attention_mask != 0).long() * positions).amax(dim=-1)
        batch_idx = torch.arange(hidden.shape[0], device=hidden.device)
        return hidden[batch_idx, last_idx.to(hidden.device)]

    def forward(
        self,
        input_ids: torch.Tensor,
        attention_mask: torch.Tensor,
        labels: Optional[torch.Tensor] = None,
        scores: Optional[torch.Tensor] = None,
        qids: Optional[torch.Tensor] = None,
    ) -> Dict[str, Any]:
        """Score a batch of sequences and optionally compute the pairwise loss.

        Args:
            input_ids: Token ids passed to the backbone.
            attention_mask: Attention mask passed to the backbone; also used to
                locate the last real token of each row.
            labels: Only checked for presence. The loss is computed only when both
                ``labels`` and ``scores`` are given, and its targets are ``scores``.
            scores: Target relevance values, one per example.
            qids: Query id per example. Loss pairs are formed only within a query.

        Returns:
            Dict with keys ``"loss"`` (None unless a loss was computed), ``"qids"``
            (echoed input), ``"score"`` (ranking scores of shape ``(batch, 1)``)
            and ``"pred_label"`` (argmax class per example).

        Raises:
            ValueError: If a loss is requested (``labels`` and ``scores`` given)
                but ``qids`` is None.
        """
        outputs = self.model(
            input_ids=input_ids,
            attention_mask=attention_mask,
        )
        # Pool the last *real* token; with right padding, position -1 would be a pad.
        cls_hidden = self._last_token_hidden(outputs.last_hidden_state, attention_mask)
        logits = self.head(cls_hidden)
        probs = torch.softmax(logits, dim=-1)
        pred_labels = torch.argmax(probs, dim=-1)
        rank_scores = self._compute_score(probs)
        loss = None
        if labels is not None and scores is not None:
            if qids is None:
                raise ValueError("qids must be provided to compute the pairwise loss")
            loss = self._pairwise_hinge_loss(rank_scores, scores, qids)
        return {
            "loss": loss,
            "qids": qids,
            "score": rank_scores,
            "pred_label": pred_labels,
        }

    def _pairwise_hinge_loss(
        self,
        logits: torch.Tensor,
        labels: torch.Tensor,
        qids: torch.Tensor,
        margin_weight: float = 0.8,
        topk: bool = False,
        pairdiff_mask: Optional[torch.Tensor] = None,
    ) -> torch.Tensor:
        """Mean pairwise hinge loss over same-query pairs.

        For each pair ``(i, j)`` the margin is
        ``(label_i - label_j [+ pairdiff_mask]) * same_query * margin_weight``.
        Pairs with a positive margin incur ``relu(margin - (s_i - s_j))``.

        Args:
            logits: Predicted ranking scores, one element per example (any shape
                that flattens to ``(batch,)``).
            labels: Target relevance values, one per example.
            qids: Query ids, one per example.
            margin_weight: Scale applied to the label gap to form the margin.
            topk: If True, the margin of a pair is doubled when its ``j`` item is
                among its query's top-ranked items (see ``_get_topk_mask``).
            pairdiff_mask: Optional additive term on the label gap; must
                broadcast against ``(batch, batch)``.

        Returns:
            Scalar tensor: summed hinge loss divided by the number of active pairs
            (plus 1e-5, so an empty pair set gives 0 rather than NaN).
        """
        # Flatten so (batch,) and (batch, 1) inputs produce (batch, batch) pair
        # matrices. 2-D qids would otherwise broadcast to (batch, batch, batch).
        qids = qids.reshape(-1)
        score_col = logits.reshape(-1, 1)
        label_col = labels.reshape(-1, 1)
        batch_size = score_col.shape[0]

        # qid_mask[i, j] == 1 iff examples i and j share a query id.
        qid_mask = (qids.unsqueeze(0) == qids.unsqueeze(1)).float()
        if topk:
            qid_mask = qid_mask * self._get_topk_mask(qids, score_col.squeeze(-1), labels)

        # Entry [i, j]: *_pos refers to example i, *_neg to example j.
        score_pos = score_col.expand(-1, batch_size)
        score_neg = score_pos.transpose(0, 1)
        pos = label_col.expand(-1, batch_size)
        neg = pos.transpose(0, 1)

        label_gap = pos - neg
        if pairdiff_mask is not None:
            label_gap = label_gap + pairdiff_mask
        margin = label_gap * qid_mask * margin_weight
        # Only pairs with a strictly positive margin are trained on.
        pair_mask = (margin > 1e-6).float()
        score_diff = score_pos - score_neg
        # Hinge term: positive while s_i - s_j falls short of the margin. The clamp
        # bounds -score_diff from below at -10.
        margin_diff = margin + torch.clamp(-score_diff, min=-10.0)
        loss = torch.relu(margin_diff) * pair_mask
        return torch.sum(loss) / (torch.sum(pair_mask) + 1e-5)

    def _get_topk_mask(self, qids, logits, labels, topk_ratio: float = 0.3):
        """Return a ``(batch, batch)`` weight matrix emphasising top-ranked items.

        For each query, the ``ceil(topk_ratio * n_valid)`` items with the highest
        predicted ``logits`` get weight 2.0 and all others 1.0. Here ``n_valid`` is
        the number of that query's items with ``labels >= 0``. Entry ``[i, j]`` of
        the result is the weight of item ``j``.

        Args:
            qids: Query ids, one per example.
            logits: Predicted scores, one per example.
            labels: Target labels, one per example; only ``>= 0`` entries count
                toward ``n_valid``.
            topk_ratio: Fraction of valid items per query to up-weight.
        """
        flat_qids = qids.reshape(-1)
        flat_logits = logits.reshape(-1)
        flat_labels = labels.reshape(-1)
        batch_size = flat_qids.shape[0]
        position_mask = torch.ones(batch_size, dtype=torch.float32, device=logits.device)
        for uq in torch.unique(flat_qids):
            indices = (flat_qids == uq).nonzero(as_tuple=True)[0]
            valid_count = (flat_labels[indices] >= 0).sum().item()
            # Round before ceil: e.g. 10 * 0.3 == 3.0000000000000004 in floating
            # point, which would otherwise give k == 4 instead of 3.
            k = math.ceil(round(valid_count * topk_ratio, 9))
            if k == 0:
                continue
            # NOTE(review): k counts only labels >= 0, but the top-k is taken over
            # all of this query's items, including those with negative labels.
            order = flat_logits[indices].argsort(descending=True)
            position_mask[indices[order[:k]]] = 2.0
        # Broadcast along rows so that result[i, j] == position_mask[j].
        return position_mask.unsqueeze(0).expand(batch_size, batch_size)

    def _compute_score(self, probs: torch.Tensor) -> torch.Tensor:
        """Map 2-class probabilities to the ranking score ``p[..., 1] - p[..., 0]``.

        Returns a tensor with the last dimension kept (size 1). The weight vector
        is float32, so lower-precision ``probs`` are promoted to float32.
        """
        weights = torch.tensor([-1.0, 1.0], device=probs.device)
        return (probs * weights).sum(dim=-1, keepdim=True)