File size: 5,092 Bytes
e4aa3d2
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
import torch
import torchaudio
import torch.nn.functional as F
from typing import Optional, List, Tuple
from tqdm import tqdm


def apply_top_k(logits, top_k):
    """Keep only the top-k highest logits per row; set the rest to -inf.

    Args:
        logits: [batch, vocab] float tensor (not modified).
        top_k: number of candidates to keep (clamped to vocab size).

    Returns:
        New tensor of the same shape with non-top-k entries at -inf.
    """
    vocab_size = logits.size(-1)
    top_k = min(top_k, vocab_size)
    top_k_values, top_k_indices = torch.topk(logits, top_k, dim=-1)
    # scatter_ writes the kept values back in one pass and stays on the
    # logits' device, avoiding the CPU-built arange index tensor the
    # advanced-indexing approach needed.
    filtered_logits = torch.full_like(logits, float("-inf"))
    filtered_logits.scatter_(-1, top_k_indices, top_k_values)
    return filtered_logits


def apply_top_p(logits, top_p):
    """Nucleus (top-p) filtering: mask tokens outside the smallest set whose
    cumulative probability exceeds ``top_p``.

    Args:
        logits: [batch, vocab] float tensor (not modified).
        top_p: cumulative-probability threshold in (0, 1].

    Returns:
        New tensor of the same shape with filtered entries set to -inf.
    """
    probs = F.softmax(logits, dim=-1)
    sorted_probs, sorted_indices = torch.sort(probs, descending=True, dim=-1)
    cumulative_probs = torch.cumsum(sorted_probs, dim=-1)

    # Mark tokens past the threshold, then shift right by one so the token
    # that first crosses top_p is itself kept; always keep the top token.
    sorted_indices_to_remove = cumulative_probs > top_p
    sorted_indices_to_remove[..., 1:] = sorted_indices_to_remove[..., :-1].clone()
    sorted_indices_to_remove[..., 0] = False

    # Scatter the removal mask back to vocabulary order in one vectorized
    # pass instead of the former per-batch-row Python loop.
    indices_to_remove = torch.zeros_like(logits, dtype=torch.bool).scatter_(
        -1, sorted_indices, sorted_indices_to_remove
    )
    return logits.masked_fill(indices_to_remove, float("-inf"))


def apply_top_p_optimized(logits, top_p):
    """Vectorized nucleus (top-p) filtering.

    Same contract as ``apply_top_p``: tokens outside the top-p nucleus are
    set to -inf.

    Args:
        logits: [batch, vocab] float tensor. Not mutated: the previous
            in-place write silently clobbered the caller's tensor (and,
            through views, tensors further up the call chain).
        top_p: cumulative-probability threshold in (0, 1].

    Returns:
        New filtered tensor of the same shape.
    """
    probs = F.softmax(logits, dim=-1)
    sorted_probs, sorted_indices = torch.sort(probs, descending=True, dim=-1)

    cumulative_probs = torch.cumsum(sorted_probs, dim=-1)

    # Shift the removal mask right by one so the token that first crosses
    # top_p is kept; the most likely token is always kept.
    sorted_indices_to_remove = cumulative_probs > top_p
    sorted_indices_to_remove[..., 1:] = sorted_indices_to_remove[..., :-1].clone()
    sorted_indices_to_remove[..., 0] = False

    indices_to_remove = torch.zeros_like(logits, dtype=torch.bool).scatter_(
        dim=-1, index=sorted_indices, src=sorted_indices_to_remove
    )

    # Out-of-place masked_fill replaces the former in-place assignment.
    return logits.masked_fill(indices_to_remove, float("-inf"))


def _penalize_logit_slice(token_logits: torch.Tensor, penalty: float) -> torch.Tensor:
    """Repetition penalty on a logits slice: shrink positive logits by
    ``penalty``, push negative ones further down by the same factor."""
    return torch.where(token_logits > 0, token_logits / penalty, token_logits * penalty)


def apply_repetition_penalty_delay_pattern(
    logits: torch.Tensor,
    prev_tokens: torch.LongTensor,
    penalty: float,
):
    """
    logits: [B, H, V]  or [N, V]
    prev_tokens: [B, T, H] or [N, T] or [B, H]

    Apply the repetition penalty independently for each H (VQ head).
    Mutates ``logits`` in place when a penalty applies, and returns it.
    """
    # Penalty of exactly 1.0 (or no history) is a no-op.
    if penalty == 1.0 or prev_tokens is None:
        return logits

    # Case 1: regular [N, V] (text layer)
    if logits.dim() == 2:
        unique_tokens = torch.unique(prev_tokens.reshape(-1))
        if unique_tokens.numel() > 0:
            logits[:, unique_tokens] = _penalize_logit_slice(
                logits[:, unique_tokens], penalty
            )
        return logits

    # Case 2: Delay Pattern audio [B, H, V]
    assert logits.dim() == 3, "Delay Pattern audio logits must be [B, H, V]"
    num_heads = logits.shape[1]

    for h in range(num_heads):
        # prev_tokens[..., h] is [B, T] or [B] depending on input rank.
        unique_tokens = torch.unique(prev_tokens[..., h].reshape(-1))

        if unique_tokens.numel() == 0:
            continue

        logits[:, h, unique_tokens] = _penalize_logit_slice(
            logits[:, h, unique_tokens], penalty
        )

    return logits


def sample_token(
    logits,
    prev_tokens: Optional[torch.LongTensor] = None,
    repetition_penalty: float = 1.0,
    top_p=None,
    top_k=None,
    do_sample=True,
):
    """Pick next tokens from ``logits`` by greedy argmax or filtered sampling.

    Args:
        logits: [..., vocab] float tensor; the trailing dim is the vocabulary.
        prev_tokens: prior tokens for the repetition penalty (layouts per
            ``apply_repetition_penalty_delay_pattern``); None disables it.
        repetition_penalty: 1.0 disables the penalty.
        top_p: nucleus threshold; applied when < 1.0.
        top_k: top-k cutoff; applied when > 0.
        do_sample: False -> greedy argmax; True -> multinomial sampling.

    Returns:
        Long tensor of token ids with shape ``logits.shape[:-1]``.
    """
    vocab_size = logits.size(-1)

    # ===== Repetition Penalty (before reshaping!) =====
    if prev_tokens is not None and repetition_penalty != 1.0:
        logits = apply_repetition_penalty_delay_pattern(
            logits,
            prev_tokens,
            repetition_penalty,
        )

    if not do_sample:
        return torch.argmax(logits, dim=-1)

    # ===== Only flatten after this, for top-k / top-p / multinomial =====
    original_shape = logits.shape
    # reshape, not view: logits may be non-contiguous (e.g. a transposed or
    # sliced tensor), and view() raises on non-contiguous input.
    reshaped_logits = logits.reshape(-1, vocab_size)

    if top_k is not None and top_k > 0:
        reshaped_logits = apply_top_k(reshaped_logits, top_k)

    if top_p is not None and top_p < 1.0:
        reshaped_logits = apply_top_p_optimized(reshaped_logits, top_p)

    probs = F.softmax(reshaped_logits, dim=-1)
    next_tokens = torch.multinomial(probs, num_samples=1)

    return next_tokens.view(original_shape[:-1])


def find_last_equal_C(tensor, C):
    """
    Locate the last occurrence of ``C`` along dim 1 of a 2-D tensor.

    tensor: torch.Tensor of shape [batch_size, seq_len]
    C: scalar value to match
    Returns: torch.Tensor of shape [batch_size] with last indices
    (-1 for rows containing no match)
    """
    batch_size, seq_len = tensor.shape
    hits = (tensor == C).int()

    # argmax over the reversed sequence finds the first hit from the right,
    # i.e. the last hit in the original ordering.
    first_from_right = hits.flip(dims=[1]).argmax(dim=1)
    candidates = (seq_len - 1) - first_from_right

    # argmax yields 0 on all-zero rows, so confirm each candidate position
    # really holds C and report -1 where it does not.
    row_ids = torch.arange(batch_size)
    candidates[tensor[row_ids, candidates] != C] = -1

    return candidates