object-assembler / code /cube3d /inference /logits_postprocesses.py
0xZohar's picture
Add missing code/cube3d/inference/logits_postprocesses.py
cfbc796 verified
import torch
import torch.nn.functional as F
def top_p_filtering(logits, top_p: float = 1.0):
    """
    Apply nucleus (top-p) filtering to a tensor of logits, in place.

    Tokens are ranked by descending logit along the last dimension. Every
    token whose cumulative probability (including itself) is greater than
    `top_p` is masked out, except the highest-ranked token, which is always
    kept so that at least one candidate survives.

    Args:
        logits (torch.Tensor): Logits to filter; the last dimension is the
            one that is ranked and masked (expected shape [..., vocab_size]).
            The tensor is modified in place.
        top_p (float, optional): Cumulative probability threshold. Filtering
            only happens when `top_p < 1.0`; otherwise `logits` is returned
            untouched.

    Returns:
        torch.Tensor: The same `logits` tensor, with filtered entries set to -inf.
    """
    # Written as a negated "<" so the no-op condition is exactly the
    # complement of the filtering condition (including for NaN thresholds).
    if not top_p < 1.0:
        return logits

    ranked_logits, ranked_pos = torch.sort(logits, dim=-1, descending=True)
    cumulative = torch.cumsum(torch.softmax(ranked_logits, dim=-1), dim=-1)

    # Mask expressed in ranked order: True marks a token to discard.
    drop_ranked = cumulative > top_p
    # The most likely token is never discarded.
    drop_ranked[..., 0] = False

    # Map the mask from ranked order back to vocabulary order.
    drop_mask = drop_ranked.scatter(-1, ranked_pos, drop_ranked)
    logits.masked_fill_(drop_mask, -torch.inf)
    return logits
def process_logits(
    logits,
    top_p: float = None,
):
    """
    Select the next token id from a tensor of logits.

    If `top_p` is None, the highest-scoring token is chosen (argmax,
    deterministic generation). Otherwise the logits are filtered with
    `top_p_filtering` using the given threshold, converted to probabilities
    with softmax, and one token is drawn with `torch.multinomial`
    (stochastic generation).

    Args:
        logits (torch.Tensor): Logits to process; the last dimension is the
            vocabulary dimension. In the sampling path the tensor is
            modified in place by `top_p_filtering`, and `torch.multinomial`
            requires the probabilities to be 1-D or 2-D.
        top_p (float, optional): Cumulative probability threshold for
            nucleus sampling. If None, argmax selection is performed.

    Returns:
        torch.Tensor: Selected token index, with a trailing dimension of
        size 1 (`keepdim=True` for argmax, `num_samples=1` for sampling).
    """
    if top_p is None:
        return torch.argmax(logits, dim=-1, keepdim=True)

    # Use the caller's threshold; a hard-coded 0.9 here previously made the
    # `top_p` argument act only as an on/off switch for sampling.
    logits = top_p_filtering(logits, top_p=top_p)
    probs = F.softmax(logits, dim=-1)
    return torch.multinomial(probs, num_samples=1, replacement=True)
def _assembly_vocab_limit(pos_id: int, stride: int):
    """Return how many leading vocabulary entries are selectable at `pos_id`, or None for no restriction."""
    # Largest id of each token family; a slot may pick ids 0..num inclusive,
    # hence the "+ 1" on every limit below.
    # NOTE(review): the names suggest part (.dat) id, rotation id and x/y/z
    # coordinate bins -- confirm against the tokenizer.
    dat_num = 604
    x_num = 213
    y_num = 217
    z_num = 529
    rot_num = 24

    # A non-positive stride has no slot layout (and `pos_id % 0` would raise),
    # so treat it as "no restriction".
    if stride <= 0:
        return None

    slot = pos_id % stride
    if slot == 0 and stride > 3:
        return dat_num + 1
    if slot == 1 and stride > 4:
        return rot_num + 1
    # The last three slots of each group are ordered y, x, z.
    if slot == stride - 3:
        return y_num + 1
    if slot == stride - 2:
        return x_num + 1
    if slot == stride - 1:
        return z_num + 1
    return None


def process_logits_assembly(
    logits,
    top_p: float = None,
    pos_id: int = 0,
    stride: int = 0
):
    """
    Select the next token id, restricting candidates by position within a stride.

    The slot `pos_id % stride` decides how many leading vocabulary entries
    may be selected:

        slot == 0          (only if stride > 3) -> first 605 entries
        slot == 1          (only if stride > 4) -> first 25 entries
        slot == stride - 3                      -> first 218 entries
        slot == stride - 2                      -> first 214 entries
        slot == stride - 1                      -> first 530 entries

    Any other slot, or a non-positive `stride` (including the default 0),
    leaves the full vocabulary available.

    If `top_p` is None, the highest-scoring allowed token is chosen (argmax).
    Otherwise the allowed logits are filtered with `top_p_filtering`,
    converted to probabilities with softmax, and one token is drawn with
    `torch.multinomial`.

    Args:
        logits (torch.Tensor): Logits to process. They are sliced as
            `logits[:, :limit]`, so dimension 1 is treated as the vocabulary
            dimension (assumes shape [batch, vocab_size] -- TODO confirm).
            In the sampling path the slice is a view, so `top_p_filtering`
            modifies the caller's tensor in place.
        top_p (float, optional): Cumulative probability threshold for
            nucleus sampling. If None, argmax selection is performed.
        pos_id (int, optional): Position of the token being generated.
        stride (int, optional): Number of tokens per repeating group.

    Returns:
        torch.Tensor: Selected token index. NOTE: the argmax path returns the
        index without a trailing dimension (no `keepdim`), while the sampling
        path returns a trailing dimension of size 1; this difference is kept
        for compatibility with existing callers.
    """
    limit = _assembly_vocab_limit(pos_id, stride)
    if limit is not None:
        logits = logits[:, :limit]

    if top_p is None:
        return logits.argmax(dim=-1)

    logits = top_p_filtering(logits, top_p=top_p)
    probs = F.softmax(logits, dim=-1)
    return torch.multinomial(probs, num_samples=1, replacement=True)