from transformers import PreTrainedModel, LongformerConfig, LongformerModel
import torch.nn as nn
import torch
class ContrastiveModel(PreTrainedModel):
    """Longformer-based pair scorer.

    Wraps a ``LongformerModel`` backbone with a single linear head and maps a
    (sentence_a, sentence_b) pair to a score in (0, 1) via a sigmoid.
    """

    config_class = LongformerConfig

    def __init__(self, config):
        super().__init__(config)
        # Built from config alone — no pretrained weights are loaded here;
        # weights arrive via from_pretrained() on this wrapper.
        self.bert = LongformerModel._from_config(config)
        self.fc = nn.Linear(config.hidden_size, 1)

    def compute_score(self, input_ids, attention_mask):
        """Return one sigmoid score per example.

        Args:
            input_ids: ``(batch, seq_len)`` token ids.
            attention_mask: ``(batch, seq_len)`` attention mask.

        Returns:
            Tensor of shape ``(batch,)`` with values in (0, 1).
        """
        outputs = self.bert(input_ids=input_ids, attention_mask=attention_mask)
        pooled_output = outputs.pooler_output
        # squeeze(-1), not a bare squeeze(): a bare squeeze on a (1, 1) output
        # would collapse a batch of size 1 to a 0-dim scalar instead of (1,),
        # breaking any caller that indexes or iterates the result.
        score = self.fc(pooled_output).squeeze(-1)
        return torch.sigmoid(score)

    def score_pair(self, sentence_a, sentence_b, tokenizer, device, max_length=2048):
        """Tokenize a single sentence pair and return its score as a float.

        NOTE(review): no ``global_attention_mask`` is passed, so the Longformer
        runs with purely local attention; classification setups usually give
        the CLS token global attention — confirm this is intended.
        NOTE(review): the model is not switched to ``eval()`` here, so any
        dropout in the backbone stays active if the caller forgot — verify.
        """
        inputs = tokenizer(
            sentence_a, sentence_b,
            return_tensors="pt", padding="max_length", truncation=True,
            max_length=max_length,
        ).to(device)
        with torch.no_grad():
            score = self.compute_score(inputs["input_ids"], inputs["attention_mask"])
        return score.item()