Spaces:
Sleeping
Sleeping
| import math | |
| import torch | |
| import torch.nn.functional as F | |
| from typing import Iterable, Union, Optional | |
| import numpy as np | |
| from numpy.random import RandomState | |
| from .mask import cosine_schedule, format_seed | |
| ################################################################################ | |
| # Utilities for sampling from trained TRIA model | |
| ################################################################################ | |
def top_p_top_k(
    logits: torch.Tensor,
    top_p: Optional[float] = None,
    top_k: Optional[int] = None,
):
    """
    Filter a logit distribution by masking out entries outside the top-k
    and/or outside the top-p (nucleus) cumulative-probability budget.

    Adapted from `vampnet.modules.transformer.sample_from_logits` by Hugo Flores
    Garcia. See: https://github.com/hugofloresgarcia/vampnet/

    Parameters
    ----------
    logits : torch.Tensor
        Shape (..., n_classes)
    top_p : float, optional
        Keep the smallest set of tokens whose cumulative probability reaches
        `top_p`; only applied when 0 < top_p < 1
    top_k : int, optional
        Keep only the `top_k` highest-scoring tokens; only applied when
        0 < top_k < n_classes

    Returns
    -------
    torch.Tensor
        Copy of `logits` with filtered entries set to -inf
    """
    logits = logits.clone()
    n_classes = logits.shape[-1]
    # Mask logits outside top-k by setting to -inf
    if top_k is not None and 0 < top_k < n_classes:
        thresh = logits.topk(top_k, dim=-1).values[..., -1:]  # (..., 1)
        logits[logits < thresh] = float("-inf")
    # Mask logits outside top-p by setting to -inf
    if top_p is not None and 0.0 < top_p < 1.0:
        # Sort descending
        sorted_logits, sorted_idx = logits.sort(dim=-1, descending=True)  # (..., n_classes)
        sorted_probs = F.softmax(sorted_logits, dim=-1)  # (..., n_classes)
        cumsum = sorted_probs.cumsum(dim=-1)  # (..., n_classes)
        to_remove = cumsum > top_p
        # BUGFIX: shift the removal mask right by one so the token that first
        # crosses the `top_p` threshold is kept. Without the shift the retained
        # mass is strictly < top_p (standard nucleus sampling keeps >= top_p);
        # this also matches the vampnet implementation cited above.
        to_remove[..., 1:] = to_remove[..., :-1].clone()
        # Keep at least one logit
        to_remove[..., 0] = False
        # Un-sort the removal mask back to original class positions
        remove_idx = torch.zeros_like(to_remove).scatter(-1, sorted_idx, to_remove)
        logits[remove_idx] = float("-inf")
    return logits
def sample(
    logits: torch.Tensor,
    temp: float,
    argmax: bool = False,
):
    """
    Draw one token per distribution from `logits`, returning both the tokens
    and the probabilities the model assigned to them.

    Adapted from `vampnet.modules.transformer.sample_from_logits` by Hugo Flores
    Garcia. See: https://github.com/hugofloresgarcia/vampnet/

    Parameters
    ----------
    logits : torch.Tensor
        Shape (..., n_classes)
    temp : float
        Sampling temperature; any value <= 0 forces greedy (argmax) decoding
    argmax : bool
        If True, take the most likely token instead of sampling

    Returns
    -------
    torch.Tensor
        Sampled tokens, shape of `logits` with trailing `n_classes` dimension
        removed
    torch.Tensor
        Probabilities of sampled tokens, shape of `logits` with trailing
        `n_classes` dimension removed
    """
    # Non-positive temperature degenerates to greedy decoding
    if temp <= 0:
        argmax, temp = True, 1.0
    if argmax:
        tokens = logits.argmax(dim=-1)
        token_probs = (
            F.softmax(logits, dim=-1)
            .take_along_dim(tokens.unsqueeze(-1), dim=-1)
            .squeeze(-1)
        )
        return tokens, token_probs
    dist = F.softmax(logits / temp, dim=-1)
    # `torch.multinomial` expects 2-D input, so collapse the leading dims,
    # draw one sample per row, then restore the original batch shape
    flattened = dist.reshape(-1, dist.shape[-1])
    tokens = torch.multinomial(flattened, 1).squeeze(-1).view(*dist.shape[:-1])
    token_probs = dist.take_along_dim(tokens.unsqueeze(-1), dim=-1).squeeze(-1)
    return tokens, token_probs
def mask_by_confidence(
    probs: torch.Tensor,
    n: torch.Tensor,
    temp: float,
    causal_bias: float,
    state: Iterable[RandomState],
    eligible: Optional[torch.Tensor] = None,
):
    """
    Re-mask predicted tokens in a single codebook, unmasking the most
    "confident" eligible positions. Confidence is the probability assigned to
    each token during sampling, optionally perturbed by Gumbel noise (scaled by
    `temp`) and a bias that favors early (leftward) positions.

    Parameters
    ----------
    probs : torch.Tensor
        Probabilities assigned to sampled tokens, shape (n_batch, n_frames)
    n : torch.Tensor
        Per-item mask-count target, shape (n_batch,). The number of eligible
        positions unmasked is `n_eligible - n`, i.e. `n` positions remain
        masked. NOTE(review): the original doc called this the "target number
        of unmasked tokens", which conflicts with that arithmetic — confirm
        intended semantics with callers.
    temp : float
        Mask temperature, corresponding to randomness in the unmasking process
    causal_bias : float
        Bias towards unmasking early (leftward) token positions first;
        typically in (0, 1]. Large values of `temp` can effectively "wash out"
        this causal bias
    state : Iterable[RandomState]
        One RandomState per batch item, for reproducible noise
    eligible : torch.Tensor, optional
        Indicator of positions eligible for unmasking, shape
        (n_batch, n_frames); defaults to positions with finite positive `probs`

    Returns
    -------
    torch.Tensor
        Boolean mask, True = unmasked/keep, shape (n_batch, n_frames)
    """
    n_batch, n_frames = probs.shape
    device = probs.device
    # Which positions may be unmasked at all
    if eligible is None:
        eligible = torch.isfinite(probs) & (probs > 0)
    else:
        eligible = eligible.to(torch.bool)
    # Number of positions to unmask per batch item
    num_eligible = eligible.long().sum(dim=-1)
    num_to_unmask = (num_eligible - n).clamp_min(0)
    # Gumbel noise drawn from the provided per-item RandomStates, so the
    # unmasking process is reproducible
    uniforms = [
        torch.from_numpy(rs.uniform(1e-6, 1 - 1e-6, n_frames)) for rs in state
    ]
    noise = -torch.log(-torch.log(torch.stack(uniforms, dim=0).to(probs)))
    # Noisy log-confidence scores
    scores = torch.log(probs.clamp_min(1e-12)) + temp * noise
    if causal_bias > 0:
        # Linear bonus that decays left-to-right, favoring early frames
        positions = torch.arange(n_frames, device=device, dtype=scores.dtype)
        relpos = (1 - (positions + 1) / n_frames).view(1, -1)
        scores = scores + causal_bias * relpos
    # Ineligible positions can never win
    scores = scores.masked_fill(~eligible, float("-inf"))
    # Pick the `num_to_unmask` highest-scoring positions in each row
    _, order = scores.sort(dim=-1, descending=True)
    ranks = torch.arange(n_frames, device=device).expand(n_batch, n_frames)
    chosen_sorted = ranks < num_to_unmask.view(n_batch, 1)
    chosen = torch.zeros_like(chosen_sorted, dtype=torch.bool).scatter(
        -1, order, chosen_sorted
    )
    # tokens_mask semantics (True = unmasked/keep): winners plus everything
    # that was never eligible for unmasking
    return chosen | ~eligible