from typing import Optional

import torch
from transformers import AutoModelForCausalLM, AutoTokenizer
|
|
|
|
def perplexity(
    model: AutoModelForCausalLM,
    tok: AutoTokenizer,
    text: str,
    max_input_length: Optional[int] = None,
):
    """
    Compute the perplexity of a piece of text under a reference causal LM.

    :param model: causal language model used to score the text.
    :param tok: tokenizer matching ``model``.
    :param text: the text to score.
    :param max_input_length: if given, the tokenized text is truncated to this
        many tokens before scoring.
    :return: perplexity as a Python float — exp of the mean negative
        log-likelihood over the predicted (next-token) positions.
    """

    # Move inputs to the model's own device rather than hard-coding "cuda",
    # so CPU-only (or multi-device) setups work too.
    inputs = tok(
        [text], return_tensors="pt", max_length=max_input_length, truncation=True
    ).to(model.device)

    # Pure scoring: no gradients needed; avoids building the autograd graph.
    with torch.no_grad():
        logits = torch.nn.functional.log_softmax(model(**inputs).logits, dim=2)

    # Log-probability assigned to each actual next token: position t's logits
    # predict token t+1, hence the shifted slices.
    log_probs = torch.gather(logits[:, :-1, :], 2, inputs["input_ids"][:, 1:, None])[0]

    # Normalize by the number of *predicted* tokens (seq_len - 1), not the full
    # sequence length — the original divided by seq_len, biasing perplexity low.
    # max(..., 1) keeps the degenerate single-token input returning 1.0 (as
    # before) instead of dividing by zero.
    n_predicted = max(log_probs.numel(), 1)
    return torch.exp(-log_probs.sum() / n_predicted).item()
|
|