import torch
import torchaudio
import torch.nn.functional as F
from typing import Optional, List, Tuple
from tqdm import tqdm


def apply_top_k(logits, top_k):
    """Keep only the top_k highest logits per row; mask everything else to -inf."""
    batch_size, vocab_size = logits.shape
    top_k = min(top_k, vocab_size)
    top_k_values, top_k_indices = torch.topk(logits, top_k, dim=-1)
    filtered_logits = torch.full_like(logits, float("-inf"))
    batch_indices = torch.arange(batch_size, device=logits.device).unsqueeze(-1)
    filtered_logits[batch_indices, top_k_indices] = top_k_values
    return filtered_logits
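

# Illustrative sketch, not part of the original module: a tiny smoke test for
# apply_top_k. With top_k=2, only the two largest logits stay finite, so softmax
# puts all probability mass on them. The helper name is hypothetical.
def _demo_apply_top_k():
    logits = torch.tensor([[0.1, 2.0, -1.0, 3.0]])
    kept = apply_top_k(logits, top_k=2)
    # Only the columns holding 3.0 and 2.0 remain finite.
    assert torch.isfinite(kept).sum().item() == 2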


def apply_top_p(logits, top_p):
    """Nucleus (top-p) filtering, per-row loop version: mask every token outside
    the smallest set whose cumulative probability exceeds top_p."""
    probs = F.softmax(logits, dim=-1)
    sorted_probs, sorted_indices = torch.sort(probs, descending=True, dim=-1)
    cumulative_probs = torch.cumsum(sorted_probs, dim=-1)
    sorted_indices_to_remove = cumulative_probs > top_p
    # Shift right so the first token above the threshold is always kept.
    sorted_indices_to_remove[..., 1:] = sorted_indices_to_remove[..., :-1].clone()
    sorted_indices_to_remove[..., 0] = False
    batch_size = logits.shape[0]
    filtered_logits = logits.clone()
    for i in range(batch_size):
        indices_to_remove = sorted_indices[i][sorted_indices_to_remove[i]]
        filtered_logits[i, indices_to_remove] = float("-inf")
    return filtered_logits


def apply_top_p_optimized(logits, top_p):
    """Vectorized nucleus filtering; note this modifies ``logits`` in place."""
    probs = F.softmax(logits, dim=-1)
    sorted_probs, sorted_indices = torch.sort(probs, descending=True, dim=-1)
    cumulative_probs = torch.cumsum(sorted_probs, dim=-1)
    sorted_indices_to_remove = cumulative_probs > top_p
    # Shift right so the first token above the threshold is always kept.
    sorted_indices_to_remove[..., 1:] = sorted_indices_to_remove[..., :-1].clone()
    sorted_indices_to_remove[..., 0] = False
    # Scatter the removal mask back from sorted order to vocabulary order.
    indices_to_remove = torch.zeros_like(logits, dtype=torch.bool).scatter_(
        dim=-1, index=sorted_indices, src=sorted_indices_to_remove
    )
    logits[indices_to_remove] = float("-inf")
    return logits
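

# Illustrative sketch, not part of the original module: the vectorized filter is
# expected to match the per-row loop version; this hypothetical helper checks
# that on random logits. apply_top_p_optimized edits its input in place, so both
# calls get their own clone.
def _demo_top_p_equivalence():
    torch.manual_seed(0)
    logits = torch.randn(4, 32)
    looped = apply_top_p(logits.clone(), top_p=0.9)
    vectorized = apply_top_p_optimized(logits.clone(), top_p=0.9)
    assert torch.equal(looped, vectorized)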


def apply_repetition_penalty_delay_pattern(
    logits: torch.Tensor,
    prev_tokens: torch.LongTensor,
    penalty: float,
):
    """
    logits: [B, H, V] or [N, V]
    prev_tokens: [B, T, H] or [N, T] or [B, H]
    Apply the repetition penalty independently for each H (VQ head).
    """
    if penalty == 1.0 or prev_tokens is None:
        return logits
    # Case 1: regular [N, V] logits (text layer)
    if logits.dim() == 2:
        prev_tokens_flat = prev_tokens.reshape(-1)
        unique_tokens = torch.unique(prev_tokens_flat)
        token_logits = logits[:, unique_tokens]
        pos_mask = token_logits > 0
        # Positive logits are divided by the penalty, negative ones multiplied,
        # so previously seen tokens always become less likely.
        token_logits[pos_mask] /= penalty
        token_logits[~pos_mask] *= penalty
        logits[:, unique_tokens] = token_logits
        return logits
    # Case 2: delay-pattern audio logits [B, H, V]
    assert logits.dim() == 3, "Delay pattern audio logits must be [B, H, V]"
    B, H, V = logits.shape
    for h in range(H):
        # prev_tokens_h: [B, T] or [B], flattened so each head only sees its own history
        prev_tokens_h = prev_tokens[..., h].reshape(-1)
        unique_tokens = torch.unique(prev_tokens_h)
        if unique_tokens.numel() == 0:
            continue
        token_logits = logits[:, h, unique_tokens]
        pos_mask = token_logits > 0
        token_logits[pos_mask] /= penalty
        token_logits[~pos_mask] *= penalty
        logits[:, h, unique_tokens] = token_logits
    return logits
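

# Illustrative sketch, not part of the original module: each VQ head is penalized
# only for the ids it has itself emitted. Shapes follow the docstring above
# ([B, H, V] logits, [B, T, H] history); the helper name and numbers are made up.
def _demo_repetition_penalty():
    logits = torch.tensor([[[1.0, 2.0, 3.0],
                            [1.0, 2.0, 3.0]]])  # [B=1, H=2, V=3]
    prev = torch.tensor([[[1, 2]]])  # [B=1, T=1, H=2]: head 0 emitted id 1, head 1 emitted id 2
    out = apply_repetition_penalty_delay_pattern(logits, prev, penalty=2.0)
    # Positive logits of previously emitted ids are divided by the penalty.
    assert out[0, 0, 1].item() == 1.0 and out[0, 1, 2].item() == 1.5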


def sample_token(
    logits,
    prev_tokens: Optional[torch.LongTensor] = None,
    repetition_penalty: float = 1.0,
    top_p=None,
    top_k=None,
    do_sample=True,
):
    vocab_size = logits.size(-1)
    # ===== Repetition penalty (before reshaping!) =====
    if prev_tokens is not None and repetition_penalty != 1.0:
        logits = apply_repetition_penalty_delay_pattern(
            logits,
            prev_tokens,
            repetition_penalty,
        )
    if not do_sample:
        return torch.argmax(logits, dim=-1)
    # ===== Only flatten after this, for top-k / top-p / multinomial =====
    original_shape = logits.shape
    reshaped_logits = logits.view(-1, vocab_size)
    if top_k is not None and top_k > 0:
        reshaped_logits = apply_top_k(reshaped_logits, top_k)
    if top_p is not None and top_p < 1.0:
        reshaped_logits = apply_top_p_optimized(reshaped_logits, top_p)
    probs = F.softmax(reshaped_logits, dim=-1)
    next_tokens = torch.multinomial(probs, num_samples=1)
    return next_tokens.view(original_shape[:-1])
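

# Illustrative sketch, not part of the original module: one decoding step over
# delay-pattern audio logits. The shapes are assumptions consistent with the
# docstrings above, and the helper name is hypothetical.
def _demo_sample_token():
    torch.manual_seed(0)
    B, H, V, T = 2, 4, 64, 10
    logits = torch.randn(B, H, V)          # one step of logits per VQ head
    prev = torch.randint(0, V, (B, T, H))  # previously generated audio tokens
    next_tokens = sample_token(
        logits,
        prev_tokens=prev,
        repetition_penalty=1.2,
        top_p=0.9,
        top_k=50,
        do_sample=True,
    )
    assert next_tokens.shape == (B, H)  # one new token per batch item and head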


def find_last_equal_C(tensor, C):
    """
    tensor: torch.Tensor of shape [batch_size, seq_len]
    C: scalar value to match
    Returns: torch.Tensor of shape [batch_size] with the last index of C in each
    row, or -1 for rows that never contain C.
    """
    mask = (tensor == C).int()  # Shape: [batch_size, seq_len], 0/1 tensor
    flipped_mask = mask.flip(dims=[1])  # Flip along the sequence dimension
    flipped_indices = flipped_mask.argmax(dim=1)  # First match in the flipped view
    seq_len = tensor.shape[1]
    last_indices = (seq_len - 1) - flipped_indices  # Convert back to original indices
    # Rows with no occurrence of C: argmax returned 0 above, so verify and mark them with -1.
    actual_values = tensor[torch.arange(tensor.shape[0], device=tensor.device), last_indices]
    no_match = actual_values != C
    last_indices[no_match] = -1
    return last_indices
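

# Illustrative sketch, not part of the original module: locating the last
# occurrence of an EOS-like id per row; rows that never contain it return -1.
def _demo_find_last_equal_C():
    seq = torch.tensor([[5, 7, 2, 7, 1],
                        [3, 3, 3, 3, 3]])
    assert find_last_equal_C(seq, C=7).tolist() == [3, -1]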