Spaces:
Running
on
Zero
Running
on
Zero
File size: 1,095 Bytes
1315cad |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 |
from __future__ import annotations
import torch
from .sampler import sample_token
def apply_classifier_guidance(
    logits: torch.Tensor,
    cfg_active: bool,
    scale: float,
    top_k: int,
) -> torch.Tensor:
    """Filter the conditional logits using classifier-free guidance.

    When guidance is active, row 0 of ``logits`` is treated as the conditional
    branch and row 1 as the unconditional branch. The two are blended as
    ``uncond + scale * (cond - uncond)`` in float32. The blended ("guided")
    logits are used only to decide which entries survive; the values returned
    are the conditional logits themselves.

    Args:
        logits: Tensor whose first dimension holds the conditional row at
            index 0 and the unconditional row at index 1; the last dimension
            is the one filtered over.
        cfg_active: If false, ``logits`` is returned untouched.
        scale: Guidance weight passed to ``torch.lerp``.
        top_k: Number of highest guided entries to keep along the last
            dimension. Values ``<= 0`` disable filtering.

    Returns:
        ``logits`` itself when ``cfg_active`` is false. Otherwise the
        conditional slice ``logits[0:1]`` in its original dtype, with every
        entry whose guided logit falls below the k-th largest guided logit
        set to ``-inf``. Entries tied with the k-th largest are all kept.
    """
    if not cfg_active:
        return logits

    conditional = logits[0:1]
    unconditional = logits[1:2]

    # Blend in float32 so low-precision inputs do not lose accuracy.
    cond32 = conditional.to(torch.float32)
    uncond32 = unconditional.to(torch.float32)
    guided = torch.lerp(uncond32, cond32, scale)

    vocab_size = guided.shape[-1]
    if top_k > 0 and vocab_size > 0:
        k = min(top_k, vocab_size)
        top_values = torch.topk(guided, k=k, dim=-1, sorted=False).values
        # With sorted=False the order of the returned values is unspecified,
        # so the k-th largest must be taken as the minimum, not the last item.
        threshold = top_values.amin(dim=-1, keepdim=True)
        mask = guided >= threshold
        neg_inf = torch.full_like(cond32, float("-inf"))
        cond32 = torch.where(mask, cond32, neg_inf)

    return cond32.to(conditional.dtype)
def sample_audio_logits(logits: torch.Tensor, temp: float, top_k: int) -> torch.Tensor:
    """Draw one audio token from ``logits`` and return it as a shape-[1] tensor.

    Sampling is delegated to ``sample_token``; ``temp`` and ``top_k`` are
    forwarded to it unchanged, and its result is reshaped with ``view(1)``.
    """
    token = sample_token(logits, temp=temp, top_k=top_k)
    return token.view(1)
|