Spaces:
Paused
Paused
| import torch | |
| import torch.nn.functional as F | |
def top_p_filtering(logits, top_p: float = 1.0):
    """
    Apply nucleus (top-p) filtering to a logits tensor, in place.

    Tokens outside the smallest prefix of the probability-sorted vocabulary
    whose cumulative probability stays within ``top_p`` are set to ``-inf``.
    The highest-probability token is always retained, so the distribution can
    never become empty.

    Args:
        logits (torch.Tensor): Logits to filter, shape ``[..., vocab_size]``.
            Modified in place.
        top_p (float, optional): Cumulative-probability threshold. Values
            ``>= 1.0`` disable filtering entirely.

    Returns:
        torch.Tensor: The same tensor, with out-of-nucleus entries at ``-inf``.
    """
    if top_p >= 1.0:
        # Threshold covers the whole distribution — nothing to mask.
        return logits
    desc_logits, desc_order = torch.sort(logits, dim=-1, descending=True)
    cum_probs = torch.cumsum(torch.softmax(desc_logits, dim=-1), dim=-1)
    # Mark everything past the threshold in sorted order...
    drop_sorted = cum_probs > top_p
    # ...but never drop the single most likely token.
    drop_sorted[..., 0] = False
    # Map the sorted-order mask back to original vocabulary positions.
    drop = drop_sorted.scatter(-1, desc_order, drop_sorted)
    logits.masked_fill_(drop, float("-inf"))
    return logits
def process_logits(
    logits,
    top_p: float = None,
):
    """
    Select the next token from logits, deterministically or by nucleus sampling.

    If ``top_p`` is None, the highest-probability token is chosen (argmax,
    deterministic). Otherwise the logits are top-p filtered, re-normalized
    with softmax, and one token is drawn with ``torch.multinomial``
    (stochastic).

    Args:
        logits (torch.Tensor): Logits over the vocabulary,
            shape ``[..., vocab_size]``.
        top_p (float, optional): Cumulative-probability threshold for nucleus
            sampling. ``None`` means argmax selection.

    Returns:
        torch.Tensor: Selected token index with a trailing dimension of 1.
    """
    if top_p is None:
        next_id = torch.argmax(logits, dim=-1, keepdim=True)
    else:
        # BUG FIX: the threshold was hard-coded to 0.9 here, silently
        # ignoring the caller-supplied top_p; pass the parameter through.
        logits = top_p_filtering(logits, top_p=top_p)
        probs = F.softmax(logits, dim=-1)
        next_id = torch.multinomial(probs, num_samples=1, replacement=True)
    return next_id
def process_logits_assembly(
    logits,
    #tokens_num: int = 19,
    top_p: float = None,
    pos_id: int = 0,
    stride: int = 0
):
    """
    Select the next token for an "assembly" sequence whose positions cycle
    through heterogeneous sub-vocabularies (data / rotation / y / x / z).

    The position within the repeating pattern (``pos_id % stride``) decides
    which leading slice of the logits is the valid candidate set; selection
    is then argmax (``top_p is None``) or nucleus sampling (``top_p`` given).

    Args:
        logits (torch.Tensor): Logit tensor. It is indexed as
            ``logits[:, :k]``, so it is assumed to be 2-D
            ``[B, vocab_size]`` — TODO confirm against callers.
        top_p (float, optional): Nucleus-sampling threshold. ``None`` selects
            argmax (deterministic); otherwise a token is sampled from the
            top-p filtered distribution (stochastic).
        pos_id (int): Current position in the generated sequence; only
            ``pos_id % stride`` is used.
        stride (int): Length of the repeating token pattern.
            NOTE(review): the default ``stride=0`` makes ``pos_id % stride``
            raise ZeroDivisionError — callers presumably always pass a
            positive stride; verify.

    Returns:
        torch.Tensor: Selected token index. Shape differs between branches:
            ``[B]`` from the argmax branch, ``[B, 1]`` from multinomial.

    NOTE(review): if none of the ``pos_id % stride`` conditions match (e.g.
    ``pos_id % stride == 0`` with ``stride <= 3``), the argmax branch raises
    UnboundLocalError and the sampling branch samples from the full,
    unsliced vocabulary — confirm these cases are unreachable for real
    strides.
    """
    # Sub-vocabulary sizes; presumably the token-range widths for each field
    # type (data token, x/y/z coordinates, rotation) — TODO confirm.
    dat_num = 604
    x_num = 213
    y_num = 217
    z_num = 529
    rot_num = 24
    # x = x_num
    # xy = x_num + y_num + rot_num
    # xyz = x_num + y_num + z_num + rot_num
    if top_p is None:
        # Deterministic path: argmax over the sub-vocabulary valid at this
        # position. The +1 widens each slice by one extra token (an
        # end/separator token, presumably — confirm against the tokenizer).
        if pos_id % stride==0 and stride>3:
            next_id = logits[:, :dat_num+1].argmax(dim=-1) # [B]
        elif pos_id % stride==1 and stride>4:
            next_id = logits[:, :rot_num+1].argmax(dim=-1)
        elif pos_id % stride==(stride-3):
            next_id = logits[:, :y_num+1].argmax(dim=-1)
        elif pos_id % stride==(stride-2):
            next_id = logits[:, :x_num+1].argmax(dim=-1)
        elif pos_id % stride==(stride-1):
            next_id = logits[:, :z_num+1].argmax(dim=-1)
    else:
        # Stochastic path: first restrict logits to the position's
        # sub-vocabulary (same branch structure as above)...
        if pos_id % stride == 0 and stride > 3:
            logits = logits[:, :dat_num+1]
        elif pos_id % stride == 1 and stride > 4:
            logits = logits[:, :rot_num+1]
        elif pos_id % stride == (stride-3):
            logits = logits[:, :y_num+1]
        elif pos_id % stride == (stride-2):
            logits = logits[:, :x_num+1]
        elif pos_id % stride == (stride-1):
            logits = logits[:, :z_num+1]
        # ...then nucleus-filter, renormalize, and draw one token per row.
        logits = top_p_filtering(logits, top_p=top_p)
        probs = F.softmax(logits, dim=-1)
        next_id = torch.multinomial(probs, num_samples=1, replacement=True)
    return next_id