File size: 4,092 Bytes
d0a3e2d
 
 
 
 
 
 
 
 
 
 
b53df79
d0a3e2d
27cdddc
 
 
 
 
 
 
d0a3e2d
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
dd4f3a8
 
d0a3e2d
b53df79
 
 
27cdddc
 
b53df79
 
 
8021b9c
f663b1c
 
 
 
27cdddc
 
 
 
f663b1c
 
 
27cdddc
f663b1c
 
 
 
27cdddc
 
 
 
 
 
 
 
f663b1c
27cdddc
 
 
 
 
 
 
 
d0a3e2d
27cdddc
edbcb21
d0a3e2d
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
# - x - x - x - x - x - x - x - x - x - x - x - x - x - x - #
#                                                           #
#   This file was created by: Alberto Palomo Alonso         #
# Universidad de Alcalá - Escuela Politécnica Superior      #
#                                                           #
# - x - x - x - x - x - x - x - x - x - x - x - x - x - x - #
# Import statements:
import torch
import tqdm
from transformers import AutoTokenizer, AutoModel
from datasets import load_dataset
from scipy.stats import spearmanr

# Mean pooling of token embeddings, ignoring padding positions via the
# attention mask (padded tokens contribute nothing to the average).
def mean_pooling(model_output, attention_mask):
    """Average the token embeddings of a transformer output over the sequence axis.

    Args:
        model_output: HuggingFace model output; element 0 holds the token
            embeddings with shape (batch, seq_len, hidden).
        attention_mask: (batch, seq_len) mask, 1 for real tokens, 0 for padding.

    Returns:
        (batch, hidden) tensor of masked mean embeddings.
    """
    embeddings = model_output[0]
    # Broadcast the mask to (batch, seq_len, hidden) so it can gate each dim.
    mask = attention_mask.unsqueeze(-1).expand(embeddings.size()).float()
    summed = (embeddings * mask).sum(dim=1)
    # Clamp avoids division by zero for an all-padding (empty) sequence.
    counts = mask.sum(dim=1).clamp(min=1e-9)
    return summed / counts


def zero_shot_proposed(
    model_repo: str = None,
    data_repo: str = None,
    batch_size: int = 32,
    device: torch.device = torch.device('cpu')
) -> float:
    """Zero-shot STS evaluation: score sentence-pair similarity with a pretrained model.

    Loads a HuggingFace model and a dataset with a ``test`` split containing
    ``sentence1``, ``sentence2`` and ``label`` columns, embeds each sentence
    pair, computes a similarity score per pair, and reports the Spearman rank
    correlation between predicted similarities and gold labels.

    Args:
        model_repo: Model repository path/identifier; prompted interactively if None.
        data_repo: Dataset repository path/identifier; prompted interactively if None.
        batch_size: Number of sentence pairs processed per forward pass.
        device: Device the model and inputs are moved to (default CPU).

    Returns:
        The Spearman correlation coefficient (rho) between gold labels and
        predicted similarities.
    """
    # Pathing: fall back to interactive prompts for missing identifiers.
    if model_repo is None:
        model_repo = input("Enter the model repository path or identifier: ")
    if data_repo is None:
        data_repo = input("Enter the dataset repository path or identifier: ")

    # Loading. NOTE(review): trust_remote_code=True executes code shipped with
    # the repo — only use with model/data repos you trust.
    tokenizer = AutoTokenizer.from_pretrained(model_repo, trust_remote_code=True)
    model = AutoModel.from_pretrained(model_repo, trust_remote_code=True)
    dataset = load_dataset(data_repo)
    model.eval()
    model.to(device)

    y = []      # gold similarity labels
    y_hat = []  # predicted similarity scores

    with torch.no_grad():
        for batch in tqdm.tqdm(dataset['test'].batch(batch_size)):
            if not hasattr(model, 'get_sentence_embedding'):
                # Generic path: mean-pool the encoder output and use cosine
                # similarity (dot product of L2-normalized embeddings).
                inputs_1 = tokenizer(batch['sentence1'], return_tensors="pt", padding=True, truncation=True, max_length=382)
                inputs_2 = tokenizer(batch['sentence2'], return_tensors="pt", padding=True, truncation=True, max_length=382)
                inputs_1 = {k: v.to(device) for k, v in inputs_1.items()}
                inputs_2 = {k: v.to(device) for k, v in inputs_2.items()}
                embeddings_1 = model(**inputs_1)
                embeddings_2 = model(**inputs_2)
                embeddings_1 = mean_pooling(embeddings_1, inputs_1['attention_mask'])
                embeddings_2 = mean_pooling(embeddings_2, inputs_2['attention_mask'])
                embeddings_1 = torch.nn.functional.normalize(embeddings_1, p=2, dim=-1)
                embeddings_2 = torch.nn.functional.normalize(embeddings_2, p=2, dim=-1)
                sim = (embeddings_1 * embeddings_2).sum(dim=-1)
            else:
                # Model-provided path: the remote-code model exposes its own
                # sentence embedding and similarity functions. Fixed-length
                # padding is used here (presumably required by the custom model).
                # NOTE(review): assumes model.similarity returns one score per
                # pair (shape (batch,)), not a batch x batch matrix — confirm
                # against the remote model's implementation.
                inputs_1 = tokenizer(batch['sentence1'], return_tensors="pt", padding='max_length', truncation=True, max_length=382)
                inputs_2 = tokenizer(batch['sentence2'], return_tensors="pt", padding='max_length', truncation=True, max_length=382)
                inputs_1 = {k: v.to(device) for k, v in inputs_1.items()}
                inputs_2 = {k: v.to(device) for k, v in inputs_2.items()}
                embeddings_1 = model.get_sentence_embedding(
                    input_ids=inputs_1["input_ids"],
                    attention_mask=inputs_1["attention_mask"]
                )
                embeddings_2 = model.get_sentence_embedding(
                    input_ids=inputs_2["input_ids"],
                    attention_mask=inputs_2["attention_mask"]
                )
                sim = model.similarity(embeddings_1, embeddings_2)

            y.extend(batch['label'])
            y_hat.extend(sim.cpu().numpy().tolist())

    # Benchmarking: print per-example predictions, then the rank correlation.
    for _y, _yh in zip(y, y_hat):
        print(f"Gold: {_y:.4f} - Predicted: {_yh:.4f}")
    rho, _ = spearmanr(y, y_hat)

    print(f"Average Spearman correlation: {rho:.4f}")
    return rho
# - x - x - x - x - x - x - x - x - x - x - x - x - x - x - #
#                        END OF FILE                        #
# - x - x - x - x - x - x - x - x - x - x - x - x - x - x - #