Spaces:
Sleeping
Sleeping
| # Copyright (c) Kyutai, all rights reserved. | |
| # This source code is licensed under the license found in the | |
| # LICENSE file in the root directory of this source tree. | |
| # Copyright (c) Meta Platforms, Inc. and affiliates. | |
| # All rights reserved. | |
| # | |
| # This source code is licensed under the license found in the | |
| # LICENSE file in the root directory of this source tree. | |
| import torch | |
| def multinomial( | |
| input: torch.Tensor, num_samples: int, replacement=False, *, generator=None | |
| ): | |
| """torch.multinomial with arbitrary number of dimensions, and number of candidates on the last dimension. | |
| Args: | |
| input (torch.Tensor): The input tensor containing probabilities. | |
| num_samples (int): Number of samples to draw. | |
| replacement (bool): Whether to draw with replacement or not. | |
| Keywords args: | |
| generator (torch.Generator): A pseudorandom number generator for sampling. | |
| Returns: | |
| torch.Tensor: Last dimension contains num_samples indices | |
| sampled from the multinomial probability distribution | |
| located in the last dimension of tensor input. | |
| """ | |
| input_ = input.reshape(-1, input.shape[-1]) | |
| # We should probably be able to remove this once the following PR has landed: | |
| # https://github.com/pytorch/pytorch/pull/134818/files | |
| # In the meantime, we specialize the case no-replacement, nsamples=1 so as not | |
| # to have a synchronization point. | |
| if replacement or num_samples != 1: | |
| output_ = torch.multinomial( | |
| input_, | |
| num_samples=num_samples, | |
| replacement=replacement, | |
| generator=generator, | |
| ) | |
| else: | |
| q = torch.empty_like(input_).exponential_(1, generator=generator) | |
| q = input_ / q | |
| output_ = q.argmax(dim=-1, keepdim=True) | |
| output = output_.reshape(*list(input.shape[:-1]), -1) | |
| return output | |
| def sample_top_k(probs: torch.Tensor, k: int) -> torch.Tensor: | |
| """Sample next token from top K values along the last dimension of the input probs tensor. | |
| Args: | |
| probs (torch.Tensor): Input probabilities with token candidates on the last dimension. | |
| k (int): The k in “top-k”. | |
| Returns: | |
| torch.Tensor: Sampled tokens. | |
| """ | |
| probs, indices = torch.topk(probs, k, dim=-1) | |
| next_token = multinomial(probs, num_samples=1) | |
| next_token = indices.gather(-1, next_token) | |
| return next_token | |
| def sample_top_p(probs: torch.Tensor, p: float) -> torch.Tensor: | |
| """Sample next token from top P probabilities along the last dimension of the input probs tensor. | |
| Args: | |
| probs (torch.Tensor): Input probabilities with token candidates on the last dimension. | |
| p (int): The p in “top-p”. | |
| Returns: | |
| torch.Tensor: Sampled tokens. | |
| """ | |
| probs_sort, probs_idx = torch.sort(probs, dim=-1, descending=True) | |
| probs_sum = torch.cumsum(probs_sort, dim=-1) | |
| mask = probs_sum - probs_sort > p | |
| probs_sort *= (~mask).float() | |
| probs_sort.div_(probs_sort.sum(dim=-1, keepdim=True)) | |
| next_token = multinomial(probs_sort, num_samples=1) | |
| next_token = torch.gather(probs_idx, -1, next_token) | |
| return next_token | |
| def sample_token( | |
| logits: torch.Tensor, | |
| use_sampling: bool = False, | |
| temp: float = 1.0, | |
| top_k: int = 0, | |
| top_p: float = 0.0, | |
| ) -> torch.Tensor: | |
| """Given logits of shape [*, Card], returns a LongTensor of shape [*].""" | |
| # Apply softmax for sampling if temp > 0. Else, do greedy sampling to avoid zero division error. | |
| if use_sampling and temp > 0.0: | |
| probs = torch.softmax(logits / temp, dim=-1) | |
| if top_p > 0.0: | |
| next_token = sample_top_p(probs, p=top_p) | |
| elif top_k > 0: | |
| next_token = sample_top_k(probs, k=top_k) | |
| else: | |
| next_token = multinomial(probs, num_samples=1) | |
| else: | |
| next_token = torch.argmax(logits, dim=-1, keepdim=True) | |
| assert next_token.shape[-1] == 1 | |
| return next_token[..., 0] | |
| if __name__ == "__main__": | |
| torch.manual_seed(1234) | |
| device = "cpu" | |
| if torch.cuda.is_available(): | |
| torch.backends.cuda.matmul.allow_tf32 = False | |
| torch.backends.cudnn.allow_tf32 = False | |
| device = "cuda:0" | |
| ps = torch.tensor([5.0, 2.0, 12.0, 6.0, 8.0, 1.0, 0.0, 4.0], device=device) | |
| cnts = torch.zeros(ps.shape, dtype=torch.long, device=device) | |
| total_samples = 1000 | |
| for _ in range(total_samples): | |
| vs = multinomial(ps, num_samples=1, replacement=False) | |
| cnts[vs] += 1 | |
| diff = cnts / cnts.sum() - ps / ps.sum() | |
| max_diff = diff.abs().max().cpu().item() | |
| print(ps / ps.sum()) | |
| print(cnts / cnts.sum()) | |
| assert max_diff < 1.5e-2 | |