from transformers import PreTrainedModel, LongformerConfig, LongformerModel
import torch
import torch.nn as nn


class ContrastiveModel(PreTrainedModel):
    """Longformer-based cross-encoder that scores a sentence pair in [0, 1].

    A single linear head on the encoder's pooled output produces one logit
    per input pair, squashed through a sigmoid.
    """

    config_class = LongformerConfig

    def __init__(self, config):
        super().__init__(config)
        # NOTE(review): attribute is named `bert` but actually holds a
        # Longformer encoder — kept as-is so existing checkpoints still load.
        self.bert = LongformerModel._from_config(config)
        # Scoring head: hidden_size -> 1 logit per pair.
        self.fc = nn.Linear(config.hidden_size, 1)

    def compute_score(self, input_ids, attention_mask):
        """Return sigmoid scores with shape (batch,) for a token batch.

        Args:
            input_ids: token-id tensor, (batch, seq_len).
            attention_mask: matching attention mask, (batch, seq_len).

        Returns:
            Tensor of probabilities in (0, 1), shape (batch,).
        """
        outputs = self.bert(input_ids=input_ids, attention_mask=attention_mask)
        pooled_output = outputs.pooler_output  # (batch, hidden_size)
        # Fix: squeeze(-1), not squeeze(). A bare squeeze() also collapses a
        # batch dimension of size 1, yielding a 0-dim tensor; squeeze(-1)
        # removes only the head's trailing size-1 dim, so a batch of one
        # correctly returns shape (1,).
        score = self.fc(pooled_output).squeeze(-1)
        return torch.sigmoid(score)

    def score_pair(self, sentence_a, sentence_b, tokenizer, device, max_length=2048):
        """Tokenize a sentence pair and return its score as a Python float.

        Args:
            sentence_a: first text segment.
            sentence_b: second text segment.
            tokenizer: a HF tokenizer; called with the pair-encoding form.
            device: device the model lives on; inputs are moved there.
            max_length: pad/truncate length for the joint encoding.

        Returns:
            float score in (0, 1).
        """
        inputs = tokenizer(
            sentence_a,
            sentence_b,
            return_tensors="pt",
            padding="max_length",
            truncation=True,
            max_length=max_length,
        ).to(device)
        # NOTE(review): inference runs under no_grad but the module is never
        # put in eval() mode here, so dropout stays active unless the caller
        # switched modes — confirm callers do so. Longformer also usually
        # wants global attention on the CLS token; none is set here.
        with torch.no_grad():
            score = self.compute_score(inputs["input_ids"], inputs["attention_mask"])
        return score.item()