import torch
import torch.nn as nn

from typing import Optional, Dict
from transformers import (
    PreTrainedModel,
    AutoModelForSequenceClassification
)
from transformers.modeling_outputs import SequenceClassifierOutput

from config import Arguments

class Reranker(nn.Module):
    """Cross-encoder reranker trained with a listwise objective over passages."""

    def __init__(self, hf_model: PreTrainedModel, args: Arguments):
        super().__init__()
        self.hf_model = hf_model
        self.args = args

        # Listwise cross-entropy over each query's passage scores.
        self.cross_entropy = nn.CrossEntropyLoss(reduction='mean')
        # Symmetric KL term for R-Drop; both arguments are log-probabilities.
        self.kl_loss_fn = torch.nn.KLDivLoss(reduction="batchmean", log_target=True)
    def forward(self, batch: Dict[str, torch.Tensor]) -> SequenceClassifierOutput:
        input_batch_dict = {k: v for k, v in batch.items() if k != 'labels'}

        if self.args.rerank_forward_factor > 1:
            # Pre-filtering requires the positive passage at index 0 for every query.
            assert torch.sum(batch['labels']).long().item() == 0
            assert all(len(v.shape) == 2 for v in input_batch_dict.values())

            is_train = self.hf_model.training
            self.hf_model.eval()

            # First, gradient-free pass: score every candidate, then keep only the
            # top-k passages per query for the gradient-carrying pass below.
            with torch.no_grad():
                outputs: SequenceClassifierOutput = self.hf_model(**input_batch_dict, return_dict=True)
                outputs.logits = outputs.logits.view(-1, self.args.train_n_passages)
                # make sure the target passage is not masked out
                outputs.logits[:, 0].fill_(float('inf'))
                k = self.args.train_n_passages // self.args.rerank_forward_factor
                _, topk_indices = torch.topk(outputs.logits, k=k, dim=-1, largest=True)
                # Offset per-query indices into flat row indices over the whole batch.
                topk_indices += self.args.train_n_passages * torch.arange(0, topk_indices.shape[0],
                                                                          dtype=torch.long,
                                                                          device=topk_indices.device).unsqueeze(-1)
                topk_indices = topk_indices.view(-1)
                input_batch_dict = {k: v.index_select(dim=0, index=topk_indices)
                                    for k, v in input_batch_dict.items()}

            self.hf_model.train(is_train)
        n_psg_per_query = self.args.train_n_passages // self.args.rerank_forward_factor
        if self.args.rerank_use_rdrop and self.training:
            # Duplicate the batch so every example is scored under two dropout masks.
            input_batch_dict = {k: torch.cat([v, v], dim=0) for k, v in input_batch_dict.items()}

        outputs: SequenceClassifierOutput = self.hf_model(**input_batch_dict, return_dict=True)

        if self.args.rerank_use_rdrop and self.training:
            logits = outputs.logits.view(2, -1, n_psg_per_query)
            outputs.logits = logits[0, :, :].contiguous()

            log_prob = torch.log_softmax(logits, dim=2)
            log_prob1, log_prob2 = log_prob[0, :, :], log_prob[1, :, :]
            # Symmetric KL between the two dropout-perturbed distributions,
            # plus cross-entropy averaged over both copies.
            rdrop_loss = 0.5 * (self.kl_loss_fn(log_prob1, log_prob2) + self.kl_loss_fn(log_prob2, log_prob1))
            ce_loss = 0.5 * (self.cross_entropy(log_prob1, batch['labels'])
                             + self.cross_entropy(log_prob2, batch['labels']))
            outputs.loss = rdrop_loss + ce_loss
        else:
            outputs.logits = outputs.logits.view(-1, n_psg_per_query)
            loss = self.cross_entropy(outputs.logits, batch['labels'])
            outputs.loss = loss

        return outputs
    @classmethod
    def from_pretrained(cls, all_args: Arguments, *args, **kwargs):
        hf_model = AutoModelForSequenceClassification.from_pretrained(*args, **kwargs)
        return cls(hf_model, all_args)

    def save_pretrained(self, output_dir: str):
        self.hf_model.save_pretrained(output_dir)
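
# --- Usage sketch (illustrative; not part of the original file) --------------
# A minimal smoke test for Reranker.forward under assumed settings. The
# checkpoint name is a placeholder for any single-logit sequence-classification
# reranker, and the SimpleNamespace stands in for config.Arguments with only
# the fields that forward() actually reads.
def _demo_reranker_forward():
    from types import SimpleNamespace
    from transformers import AutoTokenizer

    model_name = 'cross-encoder/ms-marco-MiniLM-L-6-v2'  # placeholder checkpoint
    toy_args = SimpleNamespace(train_n_passages=4, rerank_forward_factor=2,
                               rerank_use_rdrop=False)

    tokenizer = AutoTokenizer.from_pretrained(model_name)
    reranker = Reranker.from_pretrained(toy_args, model_name)
    reranker.train()

    # One query with train_n_passages candidates; the positive must sit at
    # index 0, as the pre-filtering step asserts.
    query = 'what is a reranker'
    passages = ['A reranker re-scores passages retrieved for a query.',
                'Bananas are rich in potassium.',
                'The capital of France is Paris.',
                'Cats sleep for most of the day.']
    batch = dict(tokenizer([query] * len(passages), passages,
                           padding=True, truncation=True, return_tensors='pt'))
    batch['labels'] = torch.zeros(1, dtype=torch.long)  # positive index per query

    out = reranker(batch)
    # logits shape: (num_queries, train_n_passages // rerank_forward_factor)
    print(out.loss.item(), out.logits.shape)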

class RerankerForInference(nn.Module):
    """Thin eval-mode wrapper around a sequence-classification reranker."""

    def __init__(self, hf_model: Optional[PreTrainedModel] = None):
        super().__init__()
        self.hf_model = hf_model
        self.hf_model.eval()

    @torch.no_grad()
    def forward(self, batch) -> SequenceClassifierOutput:
        return self.hf_model(**batch)
    @classmethod
    def from_pretrained(cls, pretrained_model_name_or_path: str):
        hf_model = AutoModelForSequenceClassification.from_pretrained(pretrained_model_name_or_path)
        return cls(hf_model)
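
# --- Usage sketch (illustrative; not part of the original file) --------------
# Ranking a few candidate passages with RerankerForInference; the checkpoint
# name is again a placeholder for any single-logit reranker.
def _demo_inference():
    from transformers import AutoTokenizer

    model_name = 'cross-encoder/ms-marco-MiniLM-L-6-v2'  # placeholder checkpoint
    tokenizer = AutoTokenizer.from_pretrained(model_name)
    reranker = RerankerForInference.from_pretrained(model_name)

    query = 'what is a reranker'
    passages = ['A reranker re-scores passages retrieved for a query.',
                'Bananas are rich in potassium.']
    batch = tokenizer([query] * len(passages), passages,
                      padding=True, truncation=True, return_tensors='pt')
    scores = reranker(batch).logits.squeeze(-1)
    # Higher score means more relevant; sort candidates accordingly.
    print(sorted(zip(scores.tolist(), passages), reverse=True))


if __name__ == '__main__':
    _demo_reranker_forward()
    _demo_inference()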