"""Cosine similarity smoke test on final GPT-2 LLM2Vec checkpoint."""

from __future__ import annotations

import argparse

import torch
import torch.nn.functional as F
from transformers import AutoTokenizer

from gpt2_llm2vec.models import get_model_class


def parse_args() -> argparse.Namespace:
    p = argparse.ArgumentParser(description="Print cosine similarities for a few sentence pairs.")
    p.add_argument("--model-path", default="gpt2_llm2vec/checkpoints/gpt2_llm2vec_final")
    p.add_argument("--device", default="cuda" if torch.cuda.is_available() else "cpu")
    return p.parse_args()


def mean_pool(last_hidden: torch.Tensor, attention_mask: torch.Tensor) -> torch.Tensor:
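    """Mask-aware mean pooling: average hidden states over non-padding tokens."""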
    mask = attention_mask.unsqueeze(-1).type_as(last_hidden)
    summed = (last_hidden * mask).sum(dim=1)
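    # Clamp so an all-padding row cannot divide by zero.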
    lengths = mask.sum(dim=1).clamp(min=1e-9)
    return summed / lengths


@torch.no_grad()
def encode(model, tokenizer, texts: list[str], device: str) -> torch.Tensor:
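    """Tokenize a batch, run the model's transformer backbone, and return
    L2-normalized mean-pooled sentence embeddings."""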
    enc = tokenizer(texts, padding=True, truncation=True, max_length=256, return_tensors="pt")
    enc = {k: v.to(device) for k, v in enc.items()}
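    # Call the transformer backbone directly for token-level hidden states
    # (no LM head involved).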
    out = model.transformer(**enc, return_dict=True)
    emb = mean_pool(out.last_hidden_state, enc["attention_mask"])
    return F.normalize(emb, p=2, dim=1)


def main() -> None:
    args = parse_args()
    tokenizer = AutoTokenizer.from_pretrained(args.model_path)
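    # GPT-2 has no pad token by default; reuse EOS so batched padding works.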
    if tokenizer.pad_token is None:
        tokenizer.pad_token = tokenizer.eos_token

    model_cls = get_model_class("gpt2-large")
    model = model_cls.from_pretrained(args.model_path)
    model.eval()
    model.to(args.device)

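    # Near-paraphrase pairs should score high; unrelated pairs should score low.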
    pairs = [
        ("The cat sits on the mat.", "A cat is resting on a rug."),
        ("Python is a programming language.", "Coding in Python is popular."),
        ("The stock market rose today.", "It is raining heavily outside."),
        ("How do I sort a list in Python?", "Use sorted() or list.sort() in Python."),
        ("Neural networks learn from data.", "Pizza tastes best when hot."),
    ]
    for a, b in pairs:
        e = encode(model, tokenizer, [a, b], args.device)
        sim = (e[0] * e[1]).sum().item()
        print(f"cos_sim={sim:.4f}")
        print(f"  A: {a}")
        print(f"  B: {b}")
        print()


if __name__ == "__main__":
    main()