File size: 1,091 Bytes
7b2d0eb
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
from transformers import PreTrainedModel, LongformerConfig, LongformerModel
import torch.nn as nn
import torch

class ContrastiveModel(PreTrainedModel):
    """Longformer encoder with a one-unit linear head that scores an input.

    The encoder's pooled output is projected to a single logit and passed
    through a sigmoid, so every score lies in the open interval (0, 1).
    """

    config_class = LongformerConfig

    def __init__(self, config):
        """Build the encoder and the scoring head from ``config``.

        Args:
            config: A ``LongformerConfig``; ``config.hidden_size`` sets the
                input width of the scoring head.
        """
        super().__init__(config)
        # Attribute names are part of the checkpoint's state-dict keys, so
        # they must stay ``bert`` and ``fc``.
        self.bert = LongformerModel._from_config(config)
        self.fc = nn.Linear(config.hidden_size, 1)

    def compute_score(
        self, input_ids: torch.Tensor, attention_mask: torch.Tensor
    ) -> torch.Tensor:
        """Return one sigmoid score per input sequence.

        Args:
            input_ids: Token ids, forwarded unchanged to the encoder.
            attention_mask: Attention mask, forwarded unchanged to the encoder.

        Returns:
            A tensor of scores in (0, 1) with the trailing size-1 feature
            axis removed -- presumably shape ``(batch,)``, given that the
            pooled output is ``(batch, hidden_size)``.
        """
        encoded = self.bert(input_ids=input_ids, attention_mask=attention_mask)
        logits = self.fc(encoded.pooler_output)
        # Squeeze only the last axis: a bare ``squeeze()`` would also drop the
        # batch axis when the batch holds a single example, so the output rank
        # would depend on the batch size.
        return torch.sigmoid(logits.squeeze(-1))

    def score_pair(
        self,
        sentence_a: str,
        sentence_b: str,
        tokenizer,
        device,
        max_length: int = 2048,
    ) -> float:
        """Score a single sentence pair and return the result as a float.

        The pair is tokenized as one two-segment input, padded and truncated
        to ``max_length``, moved to ``device`` and scored with gradient
        tracking disabled.

        Note:
            This method neither moves the model to ``device`` nor switches it
            to evaluation mode; the caller is expected to have done both.

        Args:
            sentence_a: First text of the pair.
            sentence_b: Second text of the pair.
            tokenizer: Callable tokenizer whose output supports ``.to(device)``
                and provides ``"input_ids"`` and ``"attention_mask"``.
            device: Device the tokenized tensors are moved to.
            max_length: Length every input is padded or truncated to.

        Returns:
            The score of the pair as a Python float.
        """
        inputs = tokenizer(
            sentence_a,
            sentence_b,
            return_tensors="pt",
            padding="max_length",
            truncation=True,
            max_length=max_length,
        ).to(device)

        with torch.no_grad():
            score = self.compute_score(
                inputs["input_ids"], inputs["attention_mask"]
            )
        # ``item()`` accepts any one-element tensor, so it works on the
        # ``(1,)`` result produced for a single pair.
        return score.item()