import numpy as np
import torch
from typing import List, Tuple
from stripedhyena.model import StripedHyena
from stripedhyena.tokenizer import CharLevelTokenizer


def prepare_batch(
    seqs: List[str],
    tokenizer: CharLevelTokenizer,
    prepend_bos: bool = True,
    device: str = 'cuda:0',
) -> Tuple[torch.Tensor, List[int]]:
    """
    Takes a list of sequences, tokenizes them, and collates them into a
    tensor batch. If the sequences have differing lengths, pads shorter
    sequences up to the maximum sequence length.
    """
    seq_lengths = [len(seq) for seq in seqs]
    max_seq_length = max(seq_lengths)

    input_ids = []
    for seq in seqs:
        padding = [tokenizer.pad_id] * (max_seq_length - len(seq))
        input_ids.append(
            torch.tensor(
                ([tokenizer.eod_id] * int(prepend_bos)) + tokenizer.tokenize(seq) + padding,
                dtype=torch.long,
            ).to(device).unsqueeze(0)
        )
    input_ids = torch.cat(input_ids, dim=0)

    return input_ids, seq_lengths
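
# A hedged illustration of the batching above: assuming the CharLevelTokenizer
# maps characters to their byte values and uses eod_id = 0 as BOS and pad_id = 1
# (the exact ids depend on the tokenizer version), two sequences of lengths 4
# and 2 would collate as:
#
#   input_ids, seq_lengths = prepare_batch(['ACGT', 'AC'], tokenizer)
#   # input_ids   -> tensor([[0, 65, 67, 71, 84],
#   #                        [0, 65, 67,  1,  1]])
#   # seq_lengths -> [4, 2]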


def logits_to_logprobs(
    logits: torch.Tensor,
    input_ids: torch.Tensor,
    trim_bos: bool = True,
) -> torch.Tensor:
    """
    Takes a tensor of logits with dimensions (batch, length, vocab).
    Computes the log-likelihoods using a softmax along the vocab dimension,
    then uses `input_ids` to index into them, returning the log-likelihood
    of the provided sequence at each position with dimensions (batch, length).
    """
    softmax_logprobs = torch.log_softmax(logits, dim=-1)
    if trim_bos:
        softmax_logprobs = softmax_logprobs[:, :-1]  # Remove last prediction.
        input_ids = input_ids[:, 1:]  # Trim BOS added by the tokenizer.
    assert softmax_logprobs.shape[1] == input_ids.shape[1]

    logprobs = torch.gather(
        softmax_logprobs,         # Gather likelihoods...
        2,                        # ...along the vocab dimension...
        input_ids.unsqueeze(-1),  # ...using the token ids to index.
    ).squeeze(-1)

    return logprobs
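
# In index terms (illustrative, with `input_ids` as passed in, BOS included):
# after trimming, entry i of each row is the log-probability the model
# assigned to the *next* token,
#
#   logprobs[b, i] = log_softmax(logits)[b, i, input_ids[b, i + 1]]
#
# so summing a row up to a sequence's length yields that sequence's
# autoregressive log-likelihood under the model.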


def score_sequences(
    seqs: List[str],
    model: StripedHyena,
    tokenizer: CharLevelTokenizer,
    reduce_method: str = 'mean',
    device: str = 'cuda:0',
) -> List[float]:
    """
    Computes the model log-likelihood scores for the sequences in `seqs`.
    Uses `reduce_method` to take the mean or sum across the per-position
    log-likelihoods (default: `'mean'`).
    Returns a list of scalar scores, one reduced log-likelihood per sequence.
    """
    input_ids, seq_lengths = prepare_batch(
        seqs, tokenizer, device=device, prepend_bos=True
    )
    assert len(seq_lengths) == input_ids.shape[0]

    with torch.inference_mode():
        logits, _ = model(input_ids)  # (batch, length, vocab)

    logprobs = logits_to_logprobs(logits, input_ids, trim_bos=True)
    logprobs = logprobs.float().cpu().numpy()

    if reduce_method == 'mean':
        reduce_func = np.mean
    elif reduce_method == 'sum':
        reduce_func = np.sum
    else:
        raise ValueError(f'Invalid reduce_method {reduce_method}')

    # Padded positions are excluded by slicing each row to the true length.
    return [
        reduce_func(logprobs[idx][:seq_lengths[idx]])
        for idx in range(len(seq_lengths))
    ]


def positional_entropies(
    seqs: List[str],
    model: StripedHyena,
    tokenizer: CharLevelTokenizer,
    device: str = 'cuda:0',
) -> List[np.ndarray]:
    """
    Computes the positional entropies for the sequences in `seqs`.
    Returns a list of arrays, each the same length as the corresponding
    sequence. Each array contains the per-position entropy of the model's
    predictive distribution across the vocab dimension.
    """
    input_ids, seq_lengths = prepare_batch(
        seqs, tokenizer, device=device, prepend_bos=True
    )
    assert len(seq_lengths) == input_ids.shape[0]

    with torch.inference_mode():
        logits, _ = model(input_ids)  # (batch, length, vocab)

    # The tokenizer prepends BOS, so remove the last prediction.
    softmax_logprobs = torch.log_softmax(logits, dim=-1)[:, :-1]
    # Shannon entropy: H = -sum_v p_v * log(p_v), with p = exp(logprob).
    entropies = -torch.sum(torch.exp(softmax_logprobs) * softmax_logprobs, dim=-1)
    entropies = entropies.float().cpu().numpy()

    sequence_entropies = [
        entropies[idx][:seq_lengths[idx]] for idx in range(len(seq_lengths))
    ]
    assert all(
        len(seq) == len(entropy) for seq, entropy in zip(seqs, sequence_entropies)
    )

    return sequence_entropies
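

if __name__ == '__main__':
    # Minimal smoke test of the batching path only (no model forward pass).
    # A sketch, assuming CharLevelTokenizer takes a vocab size of 512 as in
    # the StripedHyena reference code; adjust if your version differs.
    tokenizer = CharLevelTokenizer(512)
    batch, lengths = prepare_batch(['ACGT', 'AC'], tokenizer, device='cpu')
    print(batch.shape)  # torch.Size([2, 5]): BOS + max length of 4.
    print(lengths)      # [4, 2]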