"""Cosine similarity smoke test on final GPT-2 LLM2Vec checkpoint.""" from __future__ import annotations import argparse import torch import torch.nn.functional as F from transformers import AutoTokenizer from gpt2_llm2vec.models import get_model_class def parse_args() -> argparse.Namespace: p = argparse.ArgumentParser(description="Print cosine similarities for a few sentence pairs.") p.add_argument("--model-path", default="gpt2_llm2vec/checkpoints/gpt2_llm2vec_final") p.add_argument("--device", default="cuda" if torch.cuda.is_available() else "cpu") return p.parse_args() def mean_pool(last_hidden: torch.Tensor, attention_mask: torch.Tensor) -> torch.Tensor: mask = attention_mask.unsqueeze(-1).type_as(last_hidden) summed = (last_hidden * mask).sum(dim=1) lengths = mask.sum(dim=1).clamp(min=1e-9) return summed / lengths @torch.no_grad() def encode(model, tokenizer, texts: list[str], device: str) -> torch.Tensor: enc = tokenizer(texts, padding=True, truncation=True, max_length=256, return_tensors="pt") enc = {k: v.to(device) for k, v in enc.items()} out = model.transformer(**enc, return_dict=True) emb = mean_pool(out.last_hidden_state, enc["attention_mask"]) return F.normalize(emb, p=2, dim=1) def main() -> None: args = parse_args() tokenizer = AutoTokenizer.from_pretrained(args.model_path) if tokenizer.pad_token is None: tokenizer.pad_token = tokenizer.eos_token model_cls = get_model_class("gpt2-large") model = model_cls.from_pretrained(args.model_path) model.eval() model.to(args.device) pairs = [ ("The cat sits on the mat.", "A cat is resting on a rug."), ("Python is a programming language.", "Coding in Python is popular."), ("The stock market rose today.", "It is raining heavily outside."), ("How do I sort a list in Python?", "Use sorted() or list.sort() in Python."), ("Neural networks learn from data.", "Pizza tastes best when hot."), ] for a, b in pairs: e = encode(model, tokenizer, [a, b], args.device) sim = (e[0] * e[1]).sum().item() print(f"cos_sim={sim:.4f}") print(f" A: {a}") print(f" B: {b}") print() if __name__ == "__main__": main()