File size: 5,550 Bytes
c9f87fa
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
import math
import torch
import torch.nn.functional as F

from typing import Iterable, Union, Optional
import numpy as np
from numpy.random import RandomState

from .mask import cosine_schedule, format_seed

################################################################################
# Utilities for sampling from trained TRIA model
################################################################################


def top_p_top_k(
    logits: torch.Tensor,
    top_p: Optional[float] = None,
    top_k: Optional[int] = None,
) -> torch.Tensor:
    """
    Filter a logit distribution by setting entries outside the nucleus (top-p)
    and/or outside the top-k to ``-inf``, leaving the rest untouched.

    Adapted from `vampnet.modules.transformer.sample_from_logits` by Hugo Flores
    Garcia. See: https://github.com/hugofloresgarcia/vampnet/

    Parameters
    ----------
    logits : torch.Tensor
        Shape (..., n_classes)
    top_p : Optional[float]
        Nucleus threshold; applied only when ``0.0 < top_p < 1.0``. Classes
        past the point where the descending cumulative probability exceeds
        `top_p` are masked, but the single most probable class is always kept.
    top_k : Optional[int]
        Keep only the `top_k` largest logits; applied only when
        ``0 < top_k < n_classes``.

    Returns
    -------
    torch.Tensor
        A filtered copy of `logits`; the input tensor is not modified.
    """
    # Work on a copy so the caller's tensor is never mutated
    logits = logits.clone()
    n_classes = logits.shape[-1]

    # Mask logits outside top-k by setting to -inf
    if top_k is not None and 0 < top_k < n_classes:
        # Threshold is the k-th largest logit; values tied with it survive
        thresh = logits.topk(top_k, dim=-1).values[..., -1:]  # (..., 1)
        logits[logits < thresh] = float("-inf")

    # Mask logits outside top-p by setting to -inf
    if top_p is not None and 0.0 < top_p < 1.0:
        # Sort descending
        sorted_logits, sorted_idx = logits.sort(dim=-1, descending=True)   # (..., n_classes)
        sorted_probs = F.softmax(sorted_logits, dim=-1)                    # (..., n_classes)
        cumsum = sorted_probs.cumsum(dim=-1)                               # (..., n_classes)

        # Keep at least one logit (the most probable one), so the
        # distribution can never become empty
        to_remove = cumsum > top_p
        to_remove[..., 0] = False
        # Scatter the sorted-order removal flags back to original class order
        remove_idx = torch.zeros_like(to_remove).scatter(-1, sorted_idx, to_remove)
        logits[remove_idx] = float("-inf")

    return logits


def sample(
    logits: torch.Tensor,
    temp: float,
    argmax: bool = False,
):
    """
    Adapted from `vampnet.modules.transformer.sample_from_logits` by Hugo Flores
    Garcia. See: https://github.com/hugofloresgarcia/vampnet/
    
    Parameters
    ----------
    logits : torch.Tensor
        Shape (..., n_classes)

    Returns
    -------
    torch.Tensor
        Sampled tokens, shape of `logits` with trailing `n_classes` dimension
        removed
    torch.Tensor
        Probabilities of sampled tokens, shape of `logits` with trailing 
        `n_classes` dimension removed
    """
    # A non-positive temperature degenerates to greedy (argmax) decoding
    if temp <= 0:
        argmax = True
        temp = 1.0

    if argmax:
        # Greedy path: pick the most likely class and report its probability
        # under the un-tempered distribution
        tokens = logits.argmax(dim=-1)
        full_probs = F.softmax(logits, dim=-1)
        token_probs = full_probs.take_along_dim(
            tokens.unsqueeze(-1), dim=-1
        ).squeeze(-1)
        return tokens, token_probs

    # Stochastic path: temperature-scale, then draw one class per distribution.
    # multinomial only accepts 2-D input, so collapse leading dims first.
    dist = F.softmax(logits / temp, dim=-1)
    n_classes = dist.shape[-1]
    tokens = torch.multinomial(dist.reshape(-1, n_classes), 1)
    tokens = tokens.squeeze(-1).view(dist.shape[:-1])
    token_probs = dist.take_along_dim(tokens.unsqueeze(-1), dim=-1).squeeze(-1)
    return tokens, token_probs


def mask_by_confidence(
    probs: torch.Tensor,
    n: torch.Tensor,
    temp: float,
    causal_bias: float,
    state: Iterable[RandomState],
    eligible: Optional[torch.Tensor] = None,
) -> torch.Tensor:
    """
    Re-mask predicted tokens in a single codebook, unmasking the highest-
    confidence eligible positions so that `n` positions per batch item remain
    masked. Confidence (probability assigned to tokens during sampling) can be
    mediated by Gumbel noise and a bias to unmask early (leftward) positions
    first.

    NOTE(review): per the arithmetic below (`n_unmask = n_masked - n`), `n` is
    the number of eligible positions *left masked*, not the number unmasked —
    confirm against callers.

    Parameters
    ----------
    probs : torch.Tensor
        Probabilities assigned to sampled tokens, shape (n_batch, n_frames)
    n : torch.Tensor
        Target number of eligible positions to remain masked, shape (n_batch,)
    temp : float
        Mask temperature, corresponding to randomness in unmasking process
    causal_bias : float
        Bias towards unmasking early (leftward) token positions first; typically 
        in (0, 1]. Note that large values of `temp` can effectively "wash out"
        this causal bias
    state : Iterable[RandomState]
        Per-batch-item random states for reproducibility; one state per row,
        consumed in order
    eligible : Optional[torch.Tensor]
        Optional indicator for positions eligible for unmasking, shape
        (n_batch, n_frames). If omitted, positions with finite positive
        `probs` are eligible.

    Returns
    -------
    torch.Tensor
        Boolean mask of shape (n_batch, n_frames); True = unmasked/keep,
        False = masked.
    """
    
    n_batch, n_frames = probs.shape
    device = probs.device

    if eligible is None:
        eligible = torch.isfinite(probs) & (probs > 0)
    else:
        eligible = eligible.to(torch.bool)

    # Masked token count and number of positions to unmask (never negative)
    n_masked = eligible.long().sum(dim=-1)
    n_unmask = (n_masked - n).clamp_min(0)

    # Gumbel noise to introduce randomness into unmasking; uniform draws are
    # clamped away from {0, 1} so the double log stays finite
    u = torch.stack([
        torch.from_numpy(s.uniform(1e-6, 1 - 1e-6, n_frames)) for s in state
    ], dim=0).to(probs)
    gumbel = -torch.log(-torch.log(u))

    # Log-confidences + noise (clamp avoids log(0))
    s = probs.clamp_min(1e-12)
    confs = torch.log(s) + temp * gumbel

    # Optional causal bias in log-domain: decreases linearly from
    # (1 - 1/n_frames) at the first frame to 0 at the last
    if causal_bias > 0:
        frame_relpos = (1 - (torch.arange(n_frames, device=device, dtype=confs.dtype) + 1) / n_frames).view(1, -1)
        confs = confs + causal_bias * frame_relpos

    # Only eligible positions can be chosen; take the n_unmask highest-
    # confidence positions per row
    confs_masked = confs.masked_fill(~eligible, float("-inf"))
    sorted_idx = confs_masked.argsort(dim=-1, descending=True)
    rank = torch.arange(n_frames, device=device).view(1, n_frames).expand_as(confs_masked)
    k = n_unmask.view(n_batch, 1)
    pick_sorted = rank < k
    # Scatter the rank-order picks back to original frame order
    pick = torch.zeros_like(pick_sorted, dtype=torch.bool).scatter(-1, sorted_idx, pick_sorted)

    # Return tokens_mask semantics (True = unmasked/keep): eligible positions
    # not picked stay masked; everything else is kept
    mask = ~(eligible & (~pick))
    return mask