HUMG Cross-Encoder (Reranker)

Mô hình Cross-Encoder được fine-tune từ vinai/phobert-base để xếp hạng lại (rerank) các cặp câu hỏi - đoạn văn bản.

Kiến trúc

  • Base model: vinai/phobert-base
  • Classifier: Mean pooling → Dropout → Linear(768 → 1)
  • Loss: BCEWithLogitsLoss
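  • Max length: 512 (lưu ý: backbone vinai/phobert-base có max_position_embeddings = 258, nên thực tế mỗi chuỗi đầu vào chỉ dùng được tối đa 256 token)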

Cách sử dụng

import torch
from transformers import AutoModel, AutoTokenizer
import torch.nn as nn

class CrossEncoderModel(nn.Module):
    """Cross-encoder that scores a tokenized (question, passage) pair with one logit.

    The backbone encodes the concatenated pair. Token states are mean-pooled over
    non-padding positions, passed through dropout, and projected to a single raw
    relevance logit (apply ``torch.sigmoid`` to obtain a score in (0, 1)).

    Args:
        backbone_name: Model id or local path passed to ``AutoModel.from_pretrained``.
        hidden_dropout: Dropout probability applied to the pooled vector.
    """

    def __init__(self, backbone_name, hidden_dropout=0.1):
        super().__init__()
        self.backbone = AutoModel.from_pretrained(backbone_name)
        hidden_size = self.backbone.config.hidden_size
        self.dropout = nn.Dropout(hidden_dropout)
        self.classifier = nn.Linear(hidden_size, 1)

    def forward(self, input_ids, attention_mask):
        """Return one relevance logit per pair in the batch.

        Args:
            input_ids: Token ids, shape (batch, seq_len).
            attention_mask: 1 for real tokens and 0 for padding, same shape as
                ``input_ids``.

        Returns:
            Tensor of shape (batch,) holding raw (pre-sigmoid) logits.
        """
        out = self.backbone(input_ids=input_ids, attention_mask=attention_mask)
        hidden = out.last_hidden_state
        # Build the mask in the hidden-state dtype. A float32 mask would promote
        # fp16/bf16 activations to float32. They would then clash with the
        # half-precision classifier weights (dtype-mismatch RuntimeError).
        mask = attention_mask.unsqueeze(-1).to(hidden.dtype)
        summed = (hidden * mask).sum(dim=1)
        # Guard against division by zero for all-padding rows. 1e-9 underflows
        # to 0 in float16, so never go below the dtype's smallest normal number.
        eps = max(1e-9, torch.finfo(hidden.dtype).tiny)
        pooled = summed / mask.sum(dim=1).clamp(min=eps)
        x = self.dropout(pooled)
        logit = self.classifier(x).squeeze(-1)
        return logit

tokenizer = AutoTokenizer.from_pretrained("mudotet/humg-cross-encoder", use_fast=False)
model = CrossEncoderModel("vinai/phobert-base")
# torch.load unpickles its input, and model.pt is fetched over the network.
# weights_only=True (PyTorch >= 1.13) restricts loading to tensors and plain
# containers, which is all a state_dict needs.
state = torch.load("model.pt", map_location="cpu", weights_only=True)  # download from this repo
model.load_state_dict(state)
model.eval()  # disable dropout so scores are deterministic

# RoBERTa-style backbones such as PhoBERT reserve the first two position ids,
# so the usable sequence length is max_position_embeddings - 2. For
# vinai/phobert-base that is 256 tokens. A hard-coded 512 would overflow the
# position-embedding table on long inputs.
max_length = min(512, model.backbone.config.max_position_embeddings - 2)

# Score a question-passage pair.
# NOTE(review): PhoBERT was pre-trained on word-segmented Vietnamese
# (e.g. "câu_hỏi"); confirm inputs are segmented the same way as the
# fine-tuning data.
inputs = tokenizer("câu hỏi", "đoạn văn bản", return_tensors="pt", truncation=True, max_length=max_length)
with torch.no_grad():
    logit = model(inputs["input_ids"], inputs["attention_mask"])
    score = torch.sigmoid(logit).item()  # map the raw logit into (0, 1)
print(f"Relevance score: {score:.4f}")
Downloads last month
124
Inference Providers NEW
This model isn't deployed by any Inference Provider. 🙋 Ask for provider support