File size: 2,686 Bytes
ddb7b62
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
import torch
import transformers
from transformers import AutoTokenizer, AutoModel
from numpy import ndarray
import numpy as np
from .similarity import subspace_johnson, subspace_bert_score, vanilla_bert_score


class MySimilarity:
    """Base class for BERT-based sentence-similarity measures.

    Loads a Hugging Face tokenizer/encoder pair and exposes :meth:`encode`
    for producing token-level hidden states. Subclasses implement
    ``__call__`` with a concrete similarity metric.
    """

    def __init__(self, device='cpu', model_name_or_path='bert-base-uncased'):
        # Set up tokenizer and encoder; model is frozen (eval mode) and
        # moved to the requested device.
        self.device = device
        self.tokenizer = AutoTokenizer.from_pretrained(model_name_or_path)
        self.model = AutoModel.from_pretrained(model_name_or_path)
        self.model.eval()
        self.model.to(device)
        self.max_length = 128  # truncation length passed to the tokenizer

    def __call__(self, sentence1, sentence2, weight="L2"):
        # Fix: the original body was `pass`, so calling the base class
        # silently returned None. Raise instead so misuse fails loudly;
        # concrete subclasses override this method.
        raise NotImplementedError(
            "MySimilarity is abstract; use a subclass that implements __call__"
        )

    def encode(self, sentence, return_numpy=False, batch_size=12):
        """Encode sentence(s) into token-level hidden states.

        Args:
            sentence: a single string or a list of strings.
            return_numpy: if True, return a numpy array instead of a tensor.
            batch_size: number of sentences per forward pass.

        Returns:
            Tensor (or ndarray) of shape (num_sentences, seq_len, hidden_dim).
            A single string input still yields a leading batch dim of 1.
        """
        if isinstance(sentence, str):
            sentence = [sentence]

        embedding_list = []
        with torch.no_grad():
            # Iterate in fixed-size slices; the last slice may be shorter.
            # (Replaces the manual `len//bs + remainder` batch-count math.)
            for start in range(0, len(sentence), batch_size):
                inputs = self.tokenizer(
                    sentence[start:start + batch_size],
                    padding=True,
                    truncation=True,
                    max_length=self.max_length,
                    return_tensors="pt"
                )
                inputs = {k: v.to(self.device) for k, v in inputs.items()}
                outputs = self.model(**inputs, return_dict=True)

                # Move to CPU so accumulation does not hold GPU memory.
                embedding_list.append(outputs.last_hidden_state.cpu())

        # NOTE(review): `padding=True` pads per batch, so batches of different
        # max token length produce tensors torch.cat cannot concatenate —
        # consider padding="max_length" if multi-batch inputs are expected.
        embeddings = torch.cat(embedding_list, 0)

        if return_numpy and not isinstance(embeddings, ndarray):
            return embeddings.numpy()
        return embeddings


class SubspaceJohnsonSimilarity(MySimilarity):
    """Similarity scorer backed by the subspace Johnson measure."""

    def __call__(self, sentence1, sentence2, weight="L2"):
        # Encode both inputs to token-level states, then delegate scoring.
        states_a = self.encode(sentence1)
        states_b = self.encode(sentence2)
        return subspace_johnson(states_a, states_b, weight)
    
        
class SubspaceBERTScore(MySimilarity):
    """Similarity scorer based on the subspace variant of BERTScore."""

    def __call__(self, sentence1, sentence2, weight="L2"):
        # Delegate directly: score the two encoded hidden-state batches.
        return subspace_bert_score(
            self.encode(sentence1),
            self.encode(sentence2),
            weight,
        )
    
    
class VanillaBERTScore(MySimilarity):
    """Similarity scorer using the standard (vanilla) BERTScore metric."""

    def __call__(self, sentence1, sentence2, weight="L2"):
        # Encode each side, then score with plain BERTScore.
        emb_a, emb_b = self.encode(sentence1), self.encode(sentence2)
        return vanilla_bert_score(emb_a, emb_b, weight)