HUMG Cross-Encoder (Reranker)
Mô hình Cross-Encoder được fine-tune từ vinai/phobert-base
để xếp hạng lại (rerank) các cặp câu hỏi - đoạn văn bản.
Kiến trúc
- Base model: vinai/phobert-base
- Classifier: Mean pooling → Dropout → Linear(768 → 1)
- Loss: BCEWithLogitsLoss
- Max length: 512
Cách sử dụng
import torch
from transformers import AutoModel, AutoTokenizer
import torch.nn as nn
class CrossEncoderModel(nn.Module):
    """Cross-encoder: transformer backbone + masked mean pooling + 1-logit head.

    Args:
        backbone_name: model name or path handed to ``AutoModel.from_pretrained``.
        hidden_dropout: dropout probability applied to the pooled vector.
    """

    def __init__(self, backbone_name, hidden_dropout=0.1):
        super().__init__()
        self.backbone = AutoModel.from_pretrained(backbone_name)
        width = self.backbone.config.hidden_size
        self.dropout = nn.Dropout(hidden_dropout)
        # Single output unit: one raw relevance logit per input pair.
        self.classifier = nn.Linear(width, 1)

    def forward(self, input_ids, attention_mask):
        """Score each tokenized pair and return its raw (pre-sigmoid) logit.

        Assumes 2-D ``(batch, seq)`` inputs; the result then has shape ``(batch,)``.
        """
        encoded = self.backbone(input_ids=input_ids, attention_mask=attention_mask)
        hidden = encoded.last_hidden_state
        # Zero out padding positions so they do not contribute to the average.
        weights = attention_mask.float().unsqueeze(-1)
        summed = torch.sum(hidden * weights, dim=1)
        # Clamp guards against division by zero for an all-padding row.
        counts = torch.clamp(weights.sum(dim=1), min=1e-9)
        pooled = summed / counts
        return self.classifier(self.dropout(pooled)).squeeze(-1)
tokenizer = AutoTokenizer.from_pretrained("mudotet/humg-cross-encoder", use_fast=False)
model = CrossEncoderModel("vinai/phobert-base")

# Download model.pt from this repo first. weights_only=True restricts
# unpickling to tensors and plain containers, so a tampered checkpoint cannot
# execute arbitrary code while loading (requires torch >= 1.13).
state = torch.load("model.pt", map_location="cpu", weights_only=True)
model.load_state_dict(state)
model.eval()

# RoBERTa-architecture backbones (PhoBERT included) offset positions by two,
# so the longest usable input is max_position_embeddings - 2. Truncating to a
# larger value would let long pairs index past the position-embedding table.
max_length = min(512, model.backbone.config.max_position_embeddings - 2)

# Score a question-passage pair.
# NOTE(review): PhoBERT is normally fed word-segmented Vietnamese
# (e.g. "câu_hỏi"); confirm which preprocessing was used for fine-tuning.
inputs = tokenizer(
    "câu hỏi",
    "đoạn văn bản",
    return_tensors="pt",
    truncation=True,
    max_length=max_length,
)
with torch.no_grad():
    logit = model(inputs["input_ids"], inputs["attention_mask"])
# Sigmoid maps the raw logit to a 0-1 relevance score (training used BCEWithLogitsLoss).
score = torch.sigmoid(logit).item()
print(f"Relevance score: {score:.4f}")
- Downloads last month: 124