File size: 216 Bytes
5000658
 
 
 
 
 
 
1
2
3
4
5
6
7
8
def ppl(logits, output_ids):
    """
    Calculate the perplexity of ``output_ids`` under ``logits``.

    Perplexity here is ``exp(mean(NLL))`` where the mean runs over every
    target token, so a single scalar is returned (not one value per token).

    Args:
        logits: Unnormalized scores whose last dimension indexes the
            vocabulary. Presumably shaped (..., seq_len, vocab) — TODO
            confirm against callers.
        output_ids: Integer target ids; must have the same number of
            dimensions as ``logits`` minus its last one, since they are
            unsqueezed and used as a ``gather`` index along the last axis.

    Returns:
        The perplexity as a Python float.
    """
    # Low-precision floats (fp16/bf16/fp8) make log_softmax lossy, and exp()
    # of a mean NLL above ~11.1 overflows fp16 to inf. Upcast those to
    # float32; float32/float64 inputs are left untouched.
    if logits.is_floating_point() and logits.element_size() < 4:
        logits = logits.float()
    log_probs = logits.log_softmax(dim=-1)
    # gather needs an int64 index with the same rank as log_probs.
    token_nll = -log_probs.gather(-1, output_ids.long().unsqueeze(-1))
    return token_nll.mean().exp().item()