"""Cosine similarity smoke test on final GPT-2 LLM2Vec checkpoint."""
from __future__ import annotations
import argparse
import torch
import torch.nn.functional as F
from transformers import AutoTokenizer
from gpt2_llm2vec.models import get_model_class
def parse_args() -> argparse.Namespace:
    """Parse command-line options for the similarity smoke test.

    Returns:
        Namespace with ``model_path`` (checkpoint directory) and ``device``
        (defaults to CUDA when available, otherwise CPU).
    """
    parser = argparse.ArgumentParser(description="Print cosine similarities for a few sentence pairs.")
    parser.add_argument("--model-path", default="gpt2_llm2vec/checkpoints/gpt2_llm2vec_final")
    default_device = "cuda" if torch.cuda.is_available() else "cpu"
    parser.add_argument("--device", default=default_device)
    return parser.parse_args()
def mean_pool(last_hidden: torch.Tensor, attention_mask: torch.Tensor) -> torch.Tensor:
    """Average token embeddings over the valid (non-padding) positions.

    Args:
        last_hidden: Hidden states of shape (batch, seq_len, hidden).
        attention_mask: Binary mask of shape (batch, seq_len); 1 = real token.

    Returns:
        Tensor of shape (batch, hidden) holding per-sequence mean embeddings.
    """
    weights = attention_mask.unsqueeze(-1).type_as(last_hidden)
    token_sum = torch.sum(last_hidden * weights, dim=1)
    # Clamp guards against division by zero for an all-padding row.
    token_count = weights.sum(dim=1).clamp(min=1e-9)
    return token_sum / token_count
@torch.no_grad()
def encode(model, tokenizer, texts: list[str], device: str) -> torch.Tensor:
    """Embed *texts* into L2-normalized, mean-pooled sentence vectors.

    Tokenizes with padding/truncation to 256 tokens, runs the model's
    transformer backbone, mean-pools over valid tokens, and unit-normalizes.
    """
    batch = tokenizer(texts, padding=True, truncation=True, max_length=256, return_tensors="pt")
    batch = {name: tensor.to(device) for name, tensor in batch.items()}
    # Use the backbone directly; the LM head is irrelevant for embeddings.
    hidden = model.transformer(**batch, return_dict=True).last_hidden_state
    pooled = mean_pool(hidden, batch["attention_mask"])
    return F.normalize(pooled, p=2, dim=1)
def main() -> None:
    """Load the checkpoint and print cosine similarity for sample sentence pairs."""
    args = parse_args()

    tokenizer = AutoTokenizer.from_pretrained(args.model_path)
    # GPT-2 tokenizers ship without a pad token; reuse EOS so batching works.
    if tokenizer.pad_token is None:
        tokenizer.pad_token = tokenizer.eos_token

    model = get_model_class("gpt2-large").from_pretrained(args.model_path)
    model.eval()
    model.to(args.device)

    # Mix of similar, related, and unrelated pairs to eyeball embedding quality.
    sentence_pairs = [
        ("The cat sits on the mat.", "A cat is resting on a rug."),
        ("Python is a programming language.", "Coding in Python is popular."),
        ("The stock market rose today.", "It is raining heavily outside."),
        ("How do I sort a list in Python?", "Use sorted() or list.sort() in Python."),
        ("Neural networks learn from data.", "Pizza tastes best when hot."),
    ]
    for left, right in sentence_pairs:
        embeddings = encode(model, tokenizer, [left, right], args.device)
        # Embeddings are unit-norm, so the dot product is the cosine similarity.
        cosine = torch.dot(embeddings[0], embeddings[1]).item()
        print(f"cos_sim={cosine:.4f}")
        print(f" A: {left}")
        print(f" B: {right}")
        print()


if __name__ == "__main__":
    main()