import torch
import transformers
from transformers import AutoTokenizer, AutoModel
from numpy import ndarray
import numpy as np
from .similarity import subspace_johnson, subspace_bert_score, vanilla_bert_score
class MySimilarity:
    """Base class for sentence-similarity scorers backed by a HuggingFace
    transformer encoder.

    Subclasses implement ``__call__`` with a concrete similarity measure
    computed over the token-level hidden states produced by ``encode``.
    """

    def __init__(self, device='cpu', model_name_or_path='bert-base-uncased'):
        """Load the tokenizer and model and move the model to *device*.

        Args:
            device: torch device string (e.g. ``'cpu'`` or ``'cuda'``).
            model_name_or_path: HuggingFace model id or local checkpoint path.
        """
        self.device = device
        self.tokenizer = AutoTokenizer.from_pretrained(model_name_or_path)
        self.model = AutoModel.from_pretrained(model_name_or_path)
        self.model.eval()  # inference only; disables dropout
        self.model.to(device)
        # Truncation limit applied during tokenization.
        self.max_length = 128

    def __call__(self, sentence1, sentence2, weight="L2"):
        """Compute a similarity score between two inputs.

        Must be overridden; the original silently returned ``None``, which
        hides misuse of the base class — raise instead.
        """
        raise NotImplementedError(
            "Subclasses of MySimilarity must implement __call__")

    def encode(self, sentence, return_numpy=False, batch_size=12):
        """Encode sentences into token-level hidden states.

        Args:
            sentence: a single string or a list of strings.
            return_numpy: if True, return a numpy ``ndarray`` instead of a
                torch tensor.
            batch_size: number of sentences per forward pass.

        Returns:
            Tensor (or ndarray) of shape (num_sentences, seq_len, hidden_dim),
            where seq_len is the longest padded length across all batches.
        """
        if isinstance(sentence, str):
            sentence = [sentence]
        embedding_list = []
        with torch.no_grad():
            # Ceiling division: number of batches needed to cover all inputs.
            total_batch = -(-len(sentence) // batch_size)
            for batch_id in range(total_batch):
                inputs = self.tokenizer(
                    sentence[batch_id * batch_size:(batch_id + 1) * batch_size],
                    padding=True,
                    truncation=True,
                    max_length=self.max_length,
                    return_tensors="pt"
                )
                inputs = {k: v.to(self.device) for k, v in inputs.items()}
                outputs = self.model(**inputs, return_dict=True)
                embedding_list.append(outputs.last_hidden_state.cpu())
        # BUG FIX: with padding=True each batch pads only to its own longest
        # sentence, so dim-1 (sequence length) can differ between batches and
        # the original torch.cat crashed on multi-batch inputs. Pad every
        # batch tensor to the common maximum length before concatenating.
        max_len = max(t.size(1) for t in embedding_list)
        embedding_list = [
            torch.nn.functional.pad(t, (0, 0, 0, max_len - t.size(1)))
            for t in embedding_list
        ]
        embeddings = torch.cat(embedding_list, 0)
        if return_numpy and not isinstance(embeddings, ndarray):
            return embeddings.numpy()
        return embeddings
class SubspaceJohnsonSimilarity(MySimilarity):
    """Similarity scorer using the Johnson subspace measure over
    token-level hidden states."""

    def __call__(self, sentence1, sentence2, weight="L2"):
        """Encode both inputs and score them with ``subspace_johnson``."""
        return subspace_johnson(
            self.encode(sentence1),
            self.encode(sentence2),
            weight,
        )
class SubspaceBERTScore(MySimilarity):
    """Similarity scorer using the subspace variant of BERTScore over
    token-level hidden states."""

    def __call__(self, sentence1, sentence2, weight="L2"):
        """Encode both inputs and score them with ``subspace_bert_score``."""
        return subspace_bert_score(
            self.encode(sentence1),
            self.encode(sentence2),
            weight,
        )
class VanillaBERTScore(MySimilarity):
    """Similarity scorer using standard BERTScore over token-level
    hidden states."""

    def __call__(self, sentence1, sentence2, weight="L2"):
        """Encode both inputs and score them with ``vanilla_bert_score``."""
        return vanilla_bert_score(
            self.encode(sentence1),
            self.encode(sentence2),
            weight,
        )