| """Cosine similarity smoke test on final GPT-2 LLM2Vec checkpoint.""" |
|
|
| from __future__ import annotations |
|
|
| import argparse |
|
|
| import torch |
| import torch.nn.functional as F |
| from transformers import AutoTokenizer |
|
|
| from gpt2_llm2vec.models import get_model_class |
|
|
|
|
def parse_args() -> argparse.Namespace:
    """Build and parse the CLI options for the similarity smoke test."""
    parser = argparse.ArgumentParser(
        description="Print cosine similarities for a few sentence pairs."
    )
    parser.add_argument(
        "--model-path",
        default="gpt2_llm2vec/checkpoints/gpt2_llm2vec_final",
    )
    # Prefer GPU when one is visible; otherwise fall back to CPU.
    default_device = "cuda" if torch.cuda.is_available() else "cpu"
    parser.add_argument("--device", default=default_device)
    return parser.parse_args()
|
|
|
|
def mean_pool(last_hidden: torch.Tensor, attention_mask: torch.Tensor) -> torch.Tensor:
    """Average token embeddings over the non-padding positions.

    Args:
        last_hidden: hidden states, shape (batch, seq, hidden).
        attention_mask: 0/1 mask, shape (batch, seq).

    Returns:
        Mean-pooled embeddings of shape (batch, hidden).
    """
    weights = attention_mask.unsqueeze(-1).type_as(last_hidden)
    token_sum = torch.sum(last_hidden * weights, dim=1)
    # Clamp guards against division by zero for an all-padding row.
    denom = torch.clamp(weights.sum(dim=1), min=1e-9)
    return token_sum / denom
|
|
|
|
@torch.no_grad()
def encode(model, tokenizer, texts: list[str], device: str) -> torch.Tensor:
    """Embed *texts* as L2-normalized, mean-pooled transformer hidden states."""
    batch = tokenizer(
        texts, padding=True, truncation=True, max_length=256, return_tensors="pt"
    )
    batch = {name: tensor.to(device) for name, tensor in batch.items()}
    # NOTE(review): assumes the wrapper exposes its GPT-2 backbone as
    # `model.transformer` and that it returns `last_hidden_state` — confirm
    # against the gpt2_llm2vec model class.
    outputs = model.transformer(**batch, return_dict=True)
    pooled = mean_pool(outputs.last_hidden_state, batch["attention_mask"])
    return F.normalize(pooled, p=2, dim=1)
|
|
|
|
def main() -> None:
    """Load the final checkpoint and print cosine similarity for each pair."""
    args = parse_args()

    tokenizer = AutoTokenizer.from_pretrained(args.model_path)
    # GPT-2 tokenizers ship without a pad token; reuse EOS so batching works.
    if tokenizer.pad_token is None:
        tokenizer.pad_token = tokenizer.eos_token

    model = get_model_class("gpt2-large").from_pretrained(args.model_path)
    model.eval()
    model.to(args.device)

    # Mix of paraphrase pairs (high expected similarity) and unrelated
    # pairs (low expected similarity) for a quick sanity read.
    sentence_pairs = [
        ("The cat sits on the mat.", "A cat is resting on a rug."),
        ("Python is a programming language.", "Coding in Python is popular."),
        ("The stock market rose today.", "It is raining heavily outside."),
        ("How do I sort a list in Python?", "Use sorted() or list.sort() in Python."),
        ("Neural networks learn from data.", "Pizza tastes best when hot."),
    ]
    for left, right in sentence_pairs:
        embeddings = encode(model, tokenizer, [left, right], args.device)
        # Embeddings are unit-norm, so the dot product is the cosine.
        sim = torch.dot(embeddings[0], embeddings[1]).item()
        print(f"cos_sim={sim:.4f}")
        print(f" A: {left}")
        print(f" B: {right}")
        print()
|
|
|
|
# Run the smoke test only when executed as a script (not on import).
if __name__ == "__main__":
    main()
|
|