Spaces:
Runtime error
Runtime error
# Copyright (c) Meta Platforms, Inc. and affiliates.
#
# This source code is licensed under the Chameleon License found in the
# LICENSE file in the root directory of this source tree.
| import torch | |
class TokenSelector:
    """Interface for choosing the next token id from a probability table.

    Subclasses override ``__call__`` and return one selected token id per
    row of ``probs``.
    """

    def __call__(
        self, input_ids: torch.LongTensor, probs: torch.FloatTensor
    ) -> torch.LongTensor:
        """Select the next token for each row of ``probs``.

        Args:
            input_ids: Token ids generated so far, shape ``[batch, seq_len]``.
            probs: Next-token probabilities, shape ``[batch, vocab]``.

        Returns:
            The selected token ids (one per row of ``probs``).

        Raises:
            NotImplementedError: Always; concrete selectors must override
                this method.
        """
        # Fail loudly instead of implicitly returning None, so that using
        # the base class directly is caught at the call site rather than
        # propagating None into downstream tensor code.
        raise NotImplementedError(
            f"{type(self).__name__} must implement __call__"
        )
class ArgmaxTokenSelector(TokenSelector):
    """Greedy selector: picks the highest-probability token in every row."""

    def __call__(
        self, _: torch.LongTensor, probs: torch.FloatTensor
    ) -> torch.LongTensor:
        """Return the index of the maximum entry along dim 1 of ``probs``.

        The first argument (previously generated ids) is ignored.
        ``probs`` is expected to have shape ``[batch, vocab]``; the result
        holds one token id per batch row.
        """
        best_ids = torch.argmax(probs, dim=1)
        return best_ids
class MultinomialTokenSelector(TokenSelector):
    """Stochastic selector: draws one token per row from ``probs``."""

    def __call__(
        self, _: torch.LongTensor, probs: torch.FloatTensor
    ) -> torch.LongTensor:
        """Sample a single token id from each row of ``probs``.

        The first argument (previously generated ids) is ignored.
        ``probs`` is expected to have shape ``[batch, vocab]``; sampling
        uses torch's default random generator.
        """
        # torch.multinomial keeps a trailing sample axis: [batch, 1].
        drawn = torch.multinomial(probs, num_samples=1)
        return drawn.squeeze(dim=1)
class ReplicatedInputTokenSelector(TokenSelector):
    """Select tokens for one copy of a batch and tile them across ``n`` copies.

    The inputs are expected to stack ``n`` copies of a batch along dim 0
    (leading size ``n * batch``). Only the first copy is passed to the
    wrapped selector, and its result is repeated ``n`` times so that every
    copy receives the same tokens.
    """

    def __init__(self, token_selector: TokenSelector, n: int):
        """Wrap ``token_selector`` for inputs replicated ``n`` times.

        Args:
            token_selector: Selector applied to the first copy only.
            n: Number of stacked copies; must be at least 1.

        Raises:
            ValueError: If ``n`` is less than 1.
        """
        # Validate eagerly: torch.chunk would otherwise reject n <= 0 only
        # later, at call time.
        if n < 1:
            raise ValueError(f"n must be a positive integer, got {n}")
        self.token_selector = token_selector
        self.n = n

    def __call__(
        self, input_ids: torch.LongTensor, probs: torch.FloatTensor
    ) -> torch.LongTensor:
        """Select tokens for the first copy and repeat them ``n`` times.

        Args:
            input_ids: Shape ``[n*batch, seq_len]``.
            probs: Shape ``[n*batch, vocab]``.

        Returns:
            The wrapped selector's tokens for the first copy, tiled ``n``
            times along dim 0.

        Raises:
            ValueError: If ``probs.shape[0]`` is not a multiple of ``n``.
        """
        rows = probs.shape[0]
        # torch.chunk tolerates a non-divisible size by making unequal
        # chunks; the tiled result would then not match the input batch
        # size, so reject that case explicitly.
        if rows % self.n != 0:
            raise ValueError(
                f"probs has {rows} rows, which is not a multiple of n={self.n}"
            )
        primary_input_ids = torch.chunk(input_ids, chunks=self.n, dim=0)[0]
        primary_probs = torch.chunk(probs, chunks=self.n, dim=0)[0]
        tokens = self.token_selector(primary_input_ids, primary_probs)
        return tokens.repeat(self.n)