nas / PFMBench /src /data /esm /utils /generation.py
yuccaaa's picture
Add files using upload-large-folder tool
9627ce0 verified
import os
from typing import Any, Callable, Sequence
from warnings import warn
import attr
import torch
from tqdm import tqdm
from src.data.esm.sdk.api import (
ESM3InferenceClient,
ESMProtein,
ESMProteinError,
ESMProteinTensor,
ForwardAndSampleOutput,
ForwardTrackData,
GenerationConfig,
LogitsConfig,
LogitsOutput,
SamplingConfig,
SamplingTrackConfig,
)
from src.data.esm.tokenization import (
EsmTokenizerBase,
TokenizerCollectionProtocol,
)
from src.data.esm.tokenization.function_tokenizer import (
InterProQuantizedTokenizer,
)
from src.data.esm.utils.constants import esm3 as C
from src.data.esm.utils.misc import stack_variable_length_tensors
from src.data.esm.utils.noise_schedules import NOISE_SCHEDULE_REGISTRY
from src.data.esm.utils.sampling import (
_BatchedESMProteinTensor,
get_sampling_mask,
sample_function_logits,
sample_logits,
sample_residue_annotation_logits,
sample_sasa_logits,
)
def _trim_sequence_tensor_dataclass(o: Any, sequence_len: int):
"""Trim tensors on the sequence dimension.
This util assume that input tensor class has batch dimension.
"""
assert attr.has(o.__class__)
sliced = {}
for k, v in attr.asdict(o, recurse=False).items():
if v is None:
sliced[k] = None
elif isinstance(v, torch.Tensor):
# Trim padding.
sliced[k] = v[:, :sequence_len]
elif isinstance(v, tuple) and all(isinstance(t, torch.Tensor) for t in v):
# Trim padding for a list of tensors
sliced[k] = [t[:, :sequence_len] for t in v]
elif attr.has(v.__class__):
# Recursively slice the child attribute.
sliced[k] = _trim_sequence_tensor_dataclass(v, sequence_len)
else:
# Otherwise, simply copy the entire data bit over.
sliced[k] = v
return attr.evolve(o, **sliced)
def _slice_tensor_dataclass(o: Any, i: int, keep_dim: bool = False) -> Any:
"""Take a slice out of any attr defined Tensor objects along the batch dimension.
Args:
o: input tensor object to be sliced.
i: index of the row to be sliced.
keep_dim: whether to keep the batch dim after slicing.
For example, given a tensor of shape (5, 8), if keep_dim is True,
return a sliced tensor of shape (1, 8). Return a tensor of shape
(8,) instead if keep_dim is False. The default is False.
"""
assert attr.has(o.__class__)
sliced = {}
for k, v in attr.asdict(o, recurse=False).items():
if v is None:
sliced[k] = None
elif isinstance(v, torch.Tensor):
# Select the i-th row of each tensor.
row = v.select(0, i)
if keep_dim:
row = row.unsqueeze(0)
sliced[k] = row
elif attr.has(v.__class__):
# Recursively slice the child attribute.
sliced[k] = _slice_tensor_dataclass(v, i, keep_dim)
else:
# Otherwise, simply copy the entire data bit over.
sliced[k] = v
return attr.evolve(o, **sliced)
def iterative_sampling_raw(
client: ESM3InferenceClient,
proteins: list[ESMProtein],
configs: list[GenerationConfig],
) -> list[ESMProtein | ESMProteinError]:
# Keep structure tokens
input_tokens = [client.encode(protein) for protein in proteins]
output_tokens_list = client.batch_generate(input_tokens, configs)
raw_proteins: list[ESMProtein | ESMProteinError] = []
for output_tokens in output_tokens_list:
if isinstance(output_tokens, ESMProteinTensor):
raw_proteins.append(client.decode(output_tokens))
elif isinstance(output_tokens, ESMProteinError):
raw_proteins.append(output_tokens)
else:
raise ValueError(f"Unknown output type {type(output_tokens)}")
for input_protein, raw_protein, config in zip(proteins, raw_proteins, configs):
if isinstance(raw_protein, ESMProteinError):
# If this generation errored out.
continue
if config.track not in ["function", "residue_annotations"]:
# Function and residue annotation encoding/decoding is lossy
# There is no guarantee that decoding encoded tokens will yield the same input
raw_protein.function_annotations = input_protein.function_annotations
return raw_proteins
def _make_masked_inputs(
track: str, sequence_length: int, tokenizers: TokenizerCollectionProtocol
):
get_tokenizer: Callable[[str], EsmTokenizerBase] = lambda s: getattr(tokenizers, s)
has_tokenizer: Callable[[str], bool] = lambda s: hasattr(tokenizers, s)
if track == "coordinates":
dims = (sequence_length, 3, 3)
elif track == "confidence":
dims = (sequence_length,)
elif track == "attention_mask":
dims = (sequence_length,)
elif track == "function":
dims = (sequence_length, tokenizers.function.depth)
elif track == "residue_annotations":
dims = (sequence_length, C.MAX_RESIDUE_ANNOTATIONS)
else:
dims = (sequence_length,)
if track == "coordinates":
masked_tokens = torch.full(dims, torch.inf, dtype=torch.float)
elif track == "confidence":
# All-mask dummy input for confidence track.
masked_tokens = torch.full(dims, 0.0)
elif track == "attention_mask":
masked_tokens = torch.full(dims, 1, dtype=torch.bool)
elif has_tokenizer(track):
masked_tokens = torch.full(
dims, get_tokenizer(track).mask_token_id, dtype=torch.long
)
masked_tokens[0] = get_tokenizer(track).bos_token_id
masked_tokens[-1] = get_tokenizer(track).eos_token_id
else:
# Does not know how to create the dummy all masked input.
return None
return masked_tokens
def _stack_protein_tensors(
input_tokens: list[ESMProteinTensor],
sequence_lengths: list[int],
tokenizers: TokenizerCollectionProtocol,
device: str | torch.device,
) -> _BatchedESMProteinTensor:
o = _BatchedESMProteinTensor()
def _maybe_mock_input(fn, t, l):
if t is not None:
return t
# Try create dummy masked input for this prompt.
t = _make_masked_inputs(fn, l, tokenizers)
if t is not None:
t = t.to(device)
return t
def _stack_field(fn: str):
tensors = [getattr(tokens, fn) for tokens in input_tokens]
# Create all mask mock inputs for any tensors that are None.
tensors = [
_maybe_mock_input(fn, t, l) for t, l in zip(tensors, sequence_lengths)
]
# Handle any track that has all None as the input.
# We can't meaningfully stack tensors in this case, so simply batched
# them as None in _BatchedESMProteinTensor.
if all([t is None for t in tensors]):
setattr(o, fn, None)
return
if fn == "coordinates":
mask_token_id = torch.inf
else:
mask_token_id = getattr(tokenizers, fn).pad_token_id
setattr(
o,
fn,
stack_variable_length_tensors(
sequences=tensors, # type: ignore
constant_value=mask_token_id,
),
)
for f in attr.fields(ESMProteinTensor):
# We do not batch potential_sequence_of_concern field.
if f.name == "potential_sequence_of_concern":
continue
_stack_field(f.name)
return o
def _get_masked_positions(
track: str, tokens: torch.Tensor, mask_token_id: int
) -> torch.Tensor:
if track == "function":
mask = torch.all(tokens == mask_token_id, dim=-1).to(tokens.device)
else:
mask = tokens == mask_token_id
# Should not sample BOS and EOS positions.
mask[..., 0] = False
mask[..., -1] = False
return mask
def _get_iterative_sampling_mask_for_prompt_and_step(
cur_sampled: _BatchedESMProteinTensor,
sequence_lengths: torch.Tensor,
total_to_sample: torch.Tensor,
step: int,
entropy: ForwardTrackData,
config: GenerationConfig,
tokenizers: TokenizerCollectionProtocol,
) -> torch.Tensor:
"""Get sampling mask based on forward output and config.
Returns:
Sampling mask and num of positions sampled.
"""
track_to_sample = config.track
tokens = getattr(cur_sampled, track_to_sample)
device = tokens.device
shape = tokens.shape
B, L = shape[0], shape[1]
# TODO: figure out why we want this function to work with
# _BatchedESMProteinTensor in the first place. Logics below
# don't really work for batched tensors.
assert B == 1
sampling_mask = torch.ones((B, L), dtype=torch.bool, device=device)
sampling_mask[:, 0] = False # BOS
# EOS and all padding tokens.
sampling_mask &= (
torch.arange(L).repeat(B, 1) < (sequence_lengths - 1).unsqueeze(-1)
).to(device)
is_mask = _get_masked_positions(
track_to_sample, tokens, getattr(tokenizers, track_to_sample).mask_token_id
)
if not is_mask.any().item():
raise ValueError(f"Cannot sample {config.track} when input has no masks.")
sampling_mask = sampling_mask & is_mask
# Initialize schedule and masks
decoding_schedule = NOISE_SCHEDULE_REGISTRY[config.schedule]
# Calculate number of tokens to sample
still_masked = torch.sum(sampling_mask).int()
perc_masked_after_this_step = decoding_schedule(
torch.tensor((step + 1) / config.num_steps)
)
num_tokens_masked_after_this_step = (
# To avoid rounding errors, add a small epsilon.
# NOTE: Tensor.round does not cast to int,
# so it actually leads to rounding down.
# e.g. tensor(67.0000).int() = 66
perc_masked_after_this_step * total_to_sample + 0.1
).int()
num_to_sample = still_masked - num_tokens_masked_after_this_step
if config.strategy == "entropy":
track_entropy: torch.Tensor = getattr(entropy, track_to_sample).to(
device
) # (B, L) or (B, L, D)
if track_to_sample == "function":
track_entropy = track_entropy.sum(-1) # (B, L, D) -> (B, L)
track_entropy = track_entropy.masked_fill(
~sampling_mask, torch.finfo(track_entropy.dtype).max
)
_, indices = track_entropy.topk(num_to_sample, dim=-1, largest=False)
is_top_k = torch.zeros((B, L), dtype=torch.bool, device=device).scatter(
1, indices, True
)
where_to_sample = sampling_mask & is_top_k
elif config.strategy == "random":
# Skip B since we know there is only 1 prompt here.
_, masked_indices = sampling_mask.nonzero(as_tuple=True)
# Random shuffle the masked indices then select the first num_to_sample.
rnd_indices = masked_indices[torch.randperm(len(masked_indices))][
:num_to_sample
]
rnd_mask = torch.zeros_like(sampling_mask)
rnd_mask[:, rnd_indices] = True
where_to_sample = sampling_mask & rnd_mask
if track_to_sample == "function":
where_to_sample = where_to_sample.unsqueeze(-1).expand(
B, L, tokenizers.function.depth
) # (B, L) -> (B, L, D)
return where_to_sample
def _get_non_special_tokens(
protein: ESMProteinTensor, tokenizers: TokenizerCollectionProtocol
) -> int:
if protein.sequence is None:
# There is no sequence to infer the number of tokens to decode.
# So we assume the entire sequence minus bos and eos are for decoding.
return len(protein) - 2
mask = torch.ones_like(protein.sequence)
for special_token in tokenizers.sequence.special_token_ids:
if special_token == tokenizers.sequence.mask_token_id:
continue # MASK tokens need to be sampled.
mask[protein.sequence == special_token] = 0
return int(torch.sum(mask).item())
def _get_annealed_temperature(step: int, num_steps: int, initial_temperature: float):
step_ratio = step / max(1, (num_steps - 1))
return max(initial_temperature - step_ratio, 0.001) ** 2
def iterative_sampling_tokens(
client: ESM3InferenceClient,
input_tokens: list[ESMProteinTensor],
configs: list[GenerationConfig],
tokenizers: TokenizerCollectionProtocol,
) -> Sequence[ESMProteinTensor | ESMProteinError]:
devices = set([t.device for t in input_tokens])
if len(devices) > 1:
raise AttributeError(f"Input tokens on multiple devices {devices}")
sampled_tokens = [attr.evolve(tokens) for tokens in input_tokens]
# Clear structure tokens if user would like to condition only on coordinates.
for tokens, config in zip(sampled_tokens, configs):
if config.condition_on_coordinates_only and tokens.coordinates is not None:
tokens.structure = None
# Total sequence lengths.
sequence_lengths = [len(tokens) for tokens in sampled_tokens]
# Figure out the number of tokens to be sampled for each prompt.
total_to_sample = []
for protein, config in zip(sampled_tokens, configs):
track = config.track
if getattr(protein, track) is None:
# We need to sample the entire track.
num_sampling_steps = _get_non_special_tokens(protein, tokenizers)
else:
masked = _get_masked_positions(
track, getattr(protein, track), getattr(tokenizers, track).mask_token_id
)
num_sampling_steps = torch.sum(masked).item()
total_to_sample.append(num_sampling_steps)
# Users might over-specify the number of sampling steps for a given prompt
# TODO: Give a warning about mismatched num_steps and number of masks.
if (num_sampling_steps > 0) and (num_sampling_steps < config.num_steps):
config.num_steps = int(num_sampling_steps)
# Different prompts may ask for different number of decoding steps.
# For now, we simply run the max number of steps.
# TODO: return completed proteins as soon as they are finished sampling.
max_num_steps = max([config.num_steps for config in configs])
# Now stack the list to make a single batched ESMProteinTensor.
batched_tokens = _stack_protein_tensors(
sampled_tokens, sequence_lengths, tokenizers, devices.pop()
)
# Remember sampled prompts that has somehow errored out.
errors: dict[int, ESMProteinError] = {}
# Decode
disable_tqdm = bool(os.environ.get("DISABLE_ITERATIVE_SAMPLING_TQDM", False))
for t in tqdm(range(max_num_steps), disable=disable_tqdm):
forward_out = _batch_forward(client, batched_tokens)
# Sample each prompt individually, since their configuration may
# be very different.
# TODO: downstream utils work with batch dimsension.
# Group by sampling configurations and sample those prompts together.
for i, config in enumerate(configs): # B
if i in errors:
# This prompts has errored out in previous steps.
# Skip.
continue
if config.track in ["coordinates", "residue_annotations"]:
errors[i] = ESMProteinError(
error_code=500,
error_msg=f"Iterative sampling {config.track} is not supported.",
)
continue
if t >= config.num_steps:
# Done sampling for this row.
continue
per_prompt_cur_sampled = _BatchedESMProteinTensor.from_protein_tensor(
batched_tokens.slice(i)
)
per_prompt_forward_out: LogitsOutput = _slice_tensor_dataclass(
forward_out, i, keep_dim=True
)
# Trim logits to proper sequence length for this prompt.
per_prompt_forward_out = _trim_sequence_tensor_dataclass(
per_prompt_forward_out,
# Note(jungong) : we can not smiply use sequence_lenths[i] here,
# what we want is for the sequence length of the logits to match
# that of the prompt, which may or may not be padded, depending on
# whether the padding was done locally with the open source model
# (where per_prompt_cur_sampled is already padded) or by
# BatchedESM3ModelRunner (where per_prompt_cur_sampled is not padded).
len(per_prompt_cur_sampled),
)
# Handle temperature annealing, since _sample_per_prompt() doesn't have
# the concept of decoding steps.
if config.temperature_annealing:
temperature = _get_annealed_temperature(
t, config.num_steps, config.temperature
)
else:
temperature = config.temperature
track_sample_config = SamplingTrackConfig()
track_sample_config.invalid_ids = config.invalid_ids
track_sample_config.temperature = temperature
track_sample_config.top_p = config.top_p
sampling_config = SamplingConfig(**{config.track: track_sample_config}) # type: ignore
# Sampling has to be done per-prompt, since sampling configs
# are likely be different for different prompts.
per_prompt_forward_and_sample_output = _sample_per_prompt(
per_prompt_cur_sampled,
per_prompt_forward_out,
sampling_config,
tokenizers,
decode_sasa_tokens=False,
)
# All positions sampled after _sample_per_prompt() above.
# (B, L) & (B, L, D)
per_prompt_new_sampled = per_prompt_forward_and_sample_output.protein_tensor
# Find the positions we should sample this round.
assert per_prompt_forward_and_sample_output.entropy is not None
try:
where_to_sample = _get_iterative_sampling_mask_for_prompt_and_step(
per_prompt_cur_sampled,
torch.tensor(sequence_lengths[i]),
torch.tensor(total_to_sample[i]),
t,
per_prompt_forward_and_sample_output.entropy,
config,
tokenizers,
)
except ValueError as e:
errors[i] = ESMProteinError(error_code=500, error_msg=str(e))
continue
where_to_sample.to(input_tokens[0].device)
old_track_samples = getattr(per_prompt_cur_sampled, config.track)
new_track_samples = getattr(per_prompt_new_sampled, config.track)
# Iterative sampling by picking the tokens sampled this round
# from new_track_samples to old_track_samples.
new_track_samples = torch.where(
where_to_sample, new_track_samples, old_track_samples
)
# Update the corresponding row with new data.
getattr(batched_tokens, config.track)[i, ...] = new_track_samples[0]
# Un-pack to a list of single ProteinTypes.
output_tokens = [
batched_tokens.slice(i, sequence_len=sequence_lengths[i])
if i not in errors
else errors[i]
for i in range(len(input_tokens))
]
# Do not update tracks that were not sampled (e.g. keep None instead of masks)
for inputs, outputs, config in zip(input_tokens, output_tokens, configs):
if isinstance(outputs, ESMProteinError):
continue
# First restore coordinates field.
# We know coordinates can never be iteratively sampled.
setattr(outputs, "coordinates", getattr(inputs, "coordinates"))
# Maybe restore all the other fields.
for f in attr.fields(SamplingConfig):
if "embedding" in f.name or f.name == "return_hidden_states":
continue
if f.name != config.track:
setattr(outputs, f.name, getattr(inputs, f.name))
return output_tokens
def _batch_forward(client: ESM3InferenceClient, protein: _BatchedESMProteinTensor):
# Forward pass
return client.logits(
protein,
LogitsConfig(
sequence=True,
structure=True,
secondary_structure=True,
sasa=True,
function=True,
residue_annotations=True,
return_embeddings=True,
),
)
def _sample_per_prompt(
protein: _BatchedESMProteinTensor,
logits_output: LogitsOutput,
sampling_config: SamplingConfig,
tokenizers: TokenizerCollectionProtocol,
decode_sasa_tokens: bool = True,
mask_logits_of_invalid_ids: bool = True,
) -> ForwardAndSampleOutput:
assert logits_output.logits is not None
def maybe_clone(x: torch.Tensor | None) -> torch.Tensor | None:
return x.clone() if x is not None else None
# Sampling
tokens_dir = {}
track_sampling_metadata_dir: dict[str, dict | None] = {}
integer_sampling_tracks = ["sequence", "structure", "secondary_structure"]
if not decode_sasa_tokens:
integer_sampling_tracks.append("sasa")
for track in integer_sampling_tracks:
config = getattr(sampling_config, track)
if config is None:
tokens_dir[track] = maybe_clone(getattr(protein, track))
continue
tokenizer = getattr(tokenizers, track)
valid_ids = (
set(tokenizer.all_token_ids)
- set(tokenizer.special_token_ids)
- set(config.invalid_ids)
)
sampling_metadata = _sample_track(
logits=getattr(logits_output.logits, track),
tokens=getattr(protein, track),
sampling_track_config=config,
mask_idx=getattr(tokenizers, track).mask_token_id,
valid_ids=list(valid_ids),
mask_logits_of_invalid_ids=mask_logits_of_invalid_ids,
)
tokens_dir[track] = sampling_metadata.pop("sampled_tokens") # (L,)
track_sampling_metadata_dir[track] = sampling_metadata
# Sample SASA seperately (if needed)
if decode_sasa_tokens:
config = getattr(sampling_config, "sasa")
track_sampling_metadata_dir["sasa"] = None
if config is None:
tokens_dir["sasa"] = maybe_clone(getattr(protein, "sasa"))
else:
if config.topk_logprobs > 0:
warn("For SASA sampling, 'topk_logprobs' is expected to be 0.")
assert logits_output.logits.sasa is not None
assert protein.sasa is not None
valid_ids = (
set(tokenizers.sasa.all_token_ids)
- set(tokenizers.sasa.special_token_ids)
- set(config.invalid_ids)
)
sasa_logits = logits_output.logits.sasa
sasa_value = sample_sasa_logits(
sasa_logits,
protein.sasa,
sampling_track_config=config,
mask_idx=tokenizers.sasa.mask_token_id,
valid_ids=list(valid_ids),
mask_logits_of_invalid_ids=mask_logits_of_invalid_ids,
)
tokens_dir["sasa"] = sasa_value
probs = sasa_logits.softmax(dim=-1)
# Note(tjia): sasa_logits can have -inf because of invalid ids, so
# probs * sasa_logits.log_softmax(-1) is nan. We need to set
# those positions to 0 to get the correct entropy value
entropy = -(torch.nan_to_num(probs * sasa_logits.log_softmax(-1))).sum(-1)
track_sampling_metadata_dir["sasa"] = {"entropy": entropy}
# Sample function and residue annotations separately
config = getattr(sampling_config, "function")
function_logits = getattr(logits_output.logits, "function")
if config is None or function_logits is None:
tokens_dir["function"] = maybe_clone(getattr(protein, "function"))
tokens_dir["residue_annotations"] = maybe_clone(
getattr(protein, "residue_annotations")
)
else:
if config.invalid_ids is not None and len(config.invalid_ids) > 0:
warn("For function sampling, invalid_ids sampling config is not supported.")
sampling_metadata = _sample_function_track(
tokenizers.function,
tokens=getattr(protein, "function"),
logits=function_logits,
sampling_track_config=config,
)
tokens_dir["function"] = sampling_metadata.pop("sampled_tokens") # (L, D)
track_sampling_metadata_dir["function"] = sampling_metadata
sampled_tokens, _ = sample_residue_annotation_logits(
logits=logits_output.residue_annotation_logits # type: ignore
)
tokens_dir["residue_annotations"] = sampled_tokens # (L, MAX_R)
# Format output
forward_and_sample_output_dir = {}
forward_and_sample_output_dir["protein_tensor"] = ESMProteinTensor(**tokens_dir)
for property in [
"entropy",
"prob",
"logprob",
"top_prob",
"topk_logprob",
"topk_tokens",
]:
is_all_none = True
forward_track_data_dir = {}
for track in track_sampling_metadata_dir.keys():
values = track_sampling_metadata_dir[track]
if values is not None and values.get(property, None) is not None:
forward_track_data_dir[track] = values.get(property, None)
is_all_none = False
if not is_all_none:
forward_and_sample_output_dir[property] = ForwardTrackData(
**forward_track_data_dir
)
else:
forward_and_sample_output_dir[property] = None
per_res_embed = (
logits_output.embeddings # type: ignore
if sampling_config.return_per_residue_embeddings
else None
)
mean_embedding = (
# [B, L, D] -> [B, D]
logits_output.embeddings.mean(dim=1) # type: ignore
if sampling_config.return_mean_embedding
else None
)
return ForwardAndSampleOutput(
per_residue_embedding=per_res_embed,
mean_embedding=mean_embedding,
**forward_and_sample_output_dir,
)
def _sample_track(
logits: torch.Tensor,
tokens: torch.Tensor,
sampling_track_config: SamplingTrackConfig,
mask_idx: int,
valid_ids: list[int],
mask_logits_of_invalid_ids: bool = True,
) -> dict[str, torch.Tensor]:
"""Works with inputs that have batch dimension."""
# Sample in all positions
temperature = sampling_track_config.temperature
# We have to trim the logits and sampled tokens at potentially padded slots
# since the logits may be computed with a longer padded batch, while tokens
# are the original input sequence.
sampled_tokens = sample_logits(
logits,
temperature=temperature,
valid_ids=valid_ids,
top_p=sampling_track_config.top_p,
mask_logits_of_invalid_ids=mask_logits_of_invalid_ids,
)
log_probs = logits.log_softmax(-1)
sampling_mask = get_sampling_mask(tokens, sampling_track_config, mask_idx)
sampled_tokens = torch.where(sampling_mask, sampled_tokens, tokens)
return _compute_track_metadata(
sampled_tokens,
log_probs,
sampling_mask,
top_k=sampling_track_config.topk_logprobs,
)
def _sample_function_track(
function_tokenizer: InterProQuantizedTokenizer,
tokens: torch.Tensor,
logits: torch.Tensor,
sampling_track_config: SamplingTrackConfig,
) -> dict[str, torch.Tensor]:
"""Works with inputs that have batch dimension."""
# Do not sample at BOS and EOS tokens
sampling_mask = torch.ones_like(tokens, dtype=torch.bool)[..., 0] # (B, L)
sampling_mask[..., 0] = False
sampling_mask[..., -1] = False
sampled_tokens, logprobs = sample_function_logits(
logits,
function_tokenizer,
top_p=sampling_track_config.top_p,
temperature=sampling_track_config.temperature,
)
if sampling_track_config.only_sample_masked_tokens:
is_mask = torch.all(
tokens == function_tokenizer.mask_token_id, dim=-1
) # (B, L)
sampling_mask = sampling_mask & is_mask
sampled_tokens = torch.where(
sampling_mask[..., None].expand_as(sampled_tokens), sampled_tokens, tokens
) # (B, L, D)
# Set logprobs for non-sampled tokens to 0
logprobs_null = torch.full_like(logprobs, -torch.inf) # (B, L, D, V)
logprobs_null = torch.scatter(
logprobs_null, -1, tokens[..., None], torch.zeros_like(logprobs_null)[..., [0]]
)
logprobs = torch.where(
sampling_mask[..., None, None].expand_as(logprobs), logprobs, logprobs_null
) # (B, L, D, V)
function_metadata = _compute_track_metadata(
sampled_tokens,
logprobs,
sampling_mask,
top_k=sampling_track_config.topk_logprobs,
)
# Consider the entropy of the joint distribution of all function tokens at each position
function_metadata["entropy"] = function_metadata["entropy"].sum(
-1
) # (B, L, D) -> (B, L)
return function_metadata
def _compute_track_metadata(
sampled_tokens: torch.Tensor,
log_probs: torch.Tensor,
sampling_mask: torch.Tensor,
top_k: int,
) -> dict:
"""Works with inputs that have batch dimension."""
probs = torch.exp(log_probs) # (B, L)
entropy = torch.distributions.Categorical(logits=log_probs).entropy() # (B, L)
# Only compute probabilities for sampled tokens
sampled_logprob = torch.zeros_like(sampled_tokens, dtype=log_probs.dtype) # (B, L)
if sampled_tokens.dim() > sampling_mask.dim():
assert sampled_tokens.dim() == 3 # (B, L, D)
assert sampling_mask.dim() == 2 # (B, L)
sampling_mask = sampling_mask[..., None].expand_as(sampled_tokens)
sampled_tokens_valid = sampled_tokens[sampling_mask]
sampled_log_probs_valid = log_probs[sampling_mask, sampled_tokens_valid]
sampled_logprob[sampling_mask] = sampled_log_probs_valid
# Calculate extra metadata
sampled_prob = torch.exp(sampled_logprob)
top_prob = torch.max(probs, dim=-1).values
topk_logprobs, topk_tokens = torch.topk(log_probs, top_k, dim=-1)
topk_logprobs = None if top_k == 0 else topk_logprobs
topk_tokens = None if top_k == 0 else topk_tokens
return {
"entropy": entropy,
"sampled_tokens": sampled_tokens,
"prob": sampled_prob,
"logprob": sampled_logprob,
"top_prob": top_prob,
"topk_logprob": topk_logprobs,
"topk_tokens": topk_tokens,
}