# Copyright (c) Meta Platforms, Inc. and affiliates. # # This source code is licensed under the MIT license found in the # LICENSE file in the root directory of this source tree. from typing import List, Optional import torch import torch.nn.functional as F def get_line_info(): from inspect import getframeinfo, currentframe caller = getframeinfo(currentframe().f_back) return f"{caller.filename}:{caller.lineno}" def assert_shape(tensor, *shape_args): # This only returns the unbound shapes source = f'{get_line_info()}:' assert len(tensor.shape) == len(shape_args), f"{source} should have {len(shape_args)} dimensions, actually has shape {tensor.shape}" unbound = [] for i, ii in enumerate(shape_args): if ii == -1: unbound.append(tensor.shape[i]) else: assert ( tensor.shape[i] == ii ), f"{source} shape should be {shape_args}, actually {tensor.shape}" if len(unbound) == 1: return unbound[0] return unbound def assert_probs(tensor: torch.FloatTensor): """ Assert that [*,K] input tensor is valid probs in final dim. """ assert is_prob_tensor(tensor) def assert_logprobs(tensor: torch.FloatTensor): """ Assert that [*,K] input tensor is valid logprobs in final dim. """ try: assert_probs(tensor.exp()) except AssertionError: raise AssertionError('Not logprobs, perhaps you have logits and need to apply F.log_softmax?') def is_1hot_tensor(x1h): return x1h.min() == 0 and x1h.max() == 1 and \ x1h.shape[-1] > 1 and (x1h.sum(-1) == 1).all() def is_prob_tensor(p, atol=1e-3): return ( p.min() >= 0 and p.max() <= 1 and torch.isclose( p.sum(axis=-1), torch.ones_like(p.sum(axis=-1)), atol=atol).all() ).item() def add_eos_bos( seq1h, bos_idx: Optional[int] = None, eos_idx: Optional[int] = None, ): """ Helper for (possibly) prepending bos/cls and appending eos tokens. """ B, L, K = seq1h.shape to_concat = [] if bos_idx is not None: to_concat.append(F.one_hot(torch.full([B,1], bos_idx), K).to(seq1h)) to_concat.append(seq1h) if eos_idx is not None: to_concat.append(F.one_hot(torch.full([B,1], eos_idx), K).to(seq1h)) seq1h = torch.cat(to_concat, axis=1) return seq1h