# tria/model/sample.py
# (imported from upstream "initial commit" by saumyap29, revision c9f87fa)
import math
import torch
import torch.nn.functional as F
from typing import Iterable, Union, Optional
import numpy as np
from numpy.random import RandomState
from .mask import cosine_schedule, format_seed
################################################################################
# Utilities for sampling from trained TRIA model
################################################################################
def top_p_top_k(
    logits: torch.Tensor,
    top_p: Optional[float] = None,
    top_k: Optional[int] = None,
):
    """
    Filter a logit distribution with top-k and/or nucleus (top-p) truncation.

    Adapted from `vampnet.modules.transformer.sample_from_logits` by Hugo Flores
    Garcia. See: https://github.com/hugofloresgarcia/vampnet/

    Parameters
    ----------
    logits : torch.Tensor
        Shape (..., n_classes)
    top_p : Optional[float]
        If in (0, 1), remove tokens whose cumulative probability (taken in
        descending probability order) exceeds `top_p`; the single most-likely
        token is always kept. Otherwise no top-p filtering is applied.
    top_k : Optional[int]
        If in (0, n_classes), keep only the `top_k` highest logits. Otherwise
        no top-k filtering is applied.

    Returns
    -------
    torch.Tensor
        A copy of `logits` with filtered entries set to -inf.
    """
    logits = logits.clone()  # never mutate the caller's tensor
    n_classes = logits.shape[-1]
    # Mask logits outside top-k by setting to -inf
    if top_k is not None and 0 < top_k < n_classes:
        thresh = logits.topk(top_k, dim=-1).values[..., -1:]  # (..., 1)
        logits[logits < thresh] = float("-inf")
    # Mask logits outside top-p by setting to -inf
    if top_p is not None and 0.0 < top_p < 1.0:
        # Sort descending so cumulative probabilities are most-likely-first
        sorted_logits, sorted_idx = logits.sort(dim=-1, descending=True)  # (..., n_classes)
        sorted_probs = F.softmax(sorted_logits, dim=-1)  # (..., n_classes)
        cumsum = sorted_probs.cumsum(dim=-1)  # (..., n_classes)
        # Drop everything past the top_p threshold, but keep at least one logit
        to_remove = cumsum > top_p
        to_remove[..., 0] = False
        # Scatter the removal decisions back from sorted order to original positions
        remove_idx = torch.zeros_like(to_remove).scatter(-1, sorted_idx, to_remove)
        logits[remove_idx] = float("-inf")
    return logits
def sample(
    logits: torch.Tensor,
    temp: float,
    argmax: bool = False,
):
    """
    Draw one token per distribution from temperature-scaled logits.

    Adapted from `vampnet.modules.transformer.sample_from_logits` by Hugo Flores
    Garcia. See: https://github.com/hugofloresgarcia/vampnet/

    Parameters
    ----------
    logits : torch.Tensor
        Shape (..., n_classes)
    temp : float
        Sampling temperature; a non-positive value forces greedy (argmax)
        decoding.
    argmax : bool
        If True, take the most likely token instead of sampling.

    Returns
    -------
    torch.Tensor
        Sampled tokens, shape of `logits` with trailing `n_classes` dimension
        removed
    torch.Tensor
        Probabilities of sampled tokens, shape of `logits` with trailing
        `n_classes` dimension removed
    """
    # Non-positive temperature means greedy decoding.
    if temp <= 0:
        argmax = True
        temp = 1.0

    if argmax:
        # Greedy path: pick the mode and report its (untempered) probability.
        tokens = logits.argmax(dim=-1)
        dist = F.softmax(logits, dim=-1)
        token_probs = dist.gather(-1, tokens.unsqueeze(-1)).squeeze(-1)
        return tokens, token_probs

    # Stochastic path: temperature-scale, then draw one index per distribution.
    dist = F.softmax(logits / temp, dim=-1)
    flat_dist = dist.reshape(-1, dist.shape[-1])
    flat_tokens = torch.multinomial(flat_dist, 1).squeeze(-1)
    tokens = flat_tokens.view(*dist.shape[:-1])
    token_probs = dist.gather(-1, tokens.unsqueeze(-1)).squeeze(-1)
    return tokens, token_probs
def mask_by_confidence(
    probs: torch.Tensor,
    n: torch.Tensor,
    temp: float,
    causal_bias: float,
    state: Iterable[RandomState],
    eligible: Optional[torch.Tensor] = None,
):
    """
    Re-mask predicted tokens in a single codebook by confidence.

    Among the eligible positions, the highest-confidence ones are unmasked so
    that `n` positions per batch item remain masked. Confidence is the
    log-probability of the sampled token, optionally perturbed by Gumbel noise
    (scaled by `temp`) and biased towards early (leftward) frames.

    NOTE(review): the code leaves `n` positions *masked* (it unmasks
    `eligible_count - n` positions); an earlier description of `n` as the
    unmasked-token count does not match the computation below.

    Parameters
    ----------
    probs : torch.Tensor
        Probabilities assigned to sampled tokens, shape (n_batch, n_frames)
    n : torch.Tensor
        Number of eligible positions to leave masked, shape (n_batch,)
    temp : float
        Mask temperature, corresponding to randomness in unmasking process
    causal_bias : float
        Bias towards unmasking early (leftward) token positions first; typically
        in (0, 1]. Note that large values of `temp` can effectively "wash out"
        this causal bias
    state : Iterable[RandomState]
        Per-batch-item random states for reproducible Gumbel noise
    eligible : torch.Tensor
        Optional indicator for positions eligible for unmasking, shape
        (n_batch, n_frames)

    Returns
    -------
    torch.Tensor
        Boolean tokens-mask, shape (n_batch, n_frames); True = unmasked/keep.
    """
    n_batch, n_frames = probs.shape
    device = probs.device

    # Default eligibility: positions carrying a valid, positive probability.
    if eligible is None:
        can_unmask = torch.isfinite(probs) & (probs > 0)
    else:
        can_unmask = eligible.to(torch.bool)

    # How many eligible positions exist, and how many of them to unmask.
    eligible_counts = can_unmask.long().sum(dim=-1)
    unmask_counts = (eligible_counts - n).clamp_min(0)

    # Reproducible Gumbel noise: one draw of n_frames uniforms per batch item,
    # in the order the states are given.
    uniforms = [
        torch.from_numpy(rs.uniform(1e-6, 1 - 1e-6, n_frames)) for rs in state
    ]
    noise = -torch.log(-torch.log(torch.stack(uniforms, dim=0).to(probs)))

    # Noisy log-confidence scores (clamped away from log(0)).
    scores = torch.log(probs.clamp_min(1e-12)) + temp * noise

    # Optional causal bias in the log domain: earlier frames score higher.
    if causal_bias > 0:
        positions = torch.arange(n_frames, device=device, dtype=scores.dtype)
        frame_relpos = (1 - (positions + 1) / n_frames).view(1, -1)
        scores = scores + causal_bias * frame_relpos

    # Ineligible positions can never be selected for unmasking.
    scores = scores.masked_fill(~can_unmask, float("-inf"))

    # Select the top `unmask_counts` positions per batch item by rank.
    _, order = scores.sort(dim=-1, descending=True)
    ranks = torch.arange(n_frames, device=device).view(1, n_frames).expand_as(scores)
    keep_sorted = ranks < unmask_counts.view(n_batch, 1)
    keep = torch.zeros_like(keep_sorted, dtype=torch.bool).scatter(-1, order, keep_sorted)

    # tokens_mask semantics: True = unmasked/keep; only eligible-but-unpicked
    # positions stay masked.
    return ~(can_unmask & (~keep))