File size: 1,095 Bytes
aa16b75
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
from __future__ import annotations

import torch

from .sampler import sample_token


def apply_classifier_guidance(
    logits: torch.Tensor,
    cfg_active: bool,
    scale: float,
    top_k: int,
) -> torch.Tensor:
    """Mask conditional logits with a top-k filter derived from guided logits.

    When guidance is active, the first row of ``logits`` is treated as the
    conditional branch and the second row as the unconditional branch. The two
    are blended as ``uncond + scale * (cond - uncond)``; the blended ("guided")
    scores are used only to decide which vocabulary entries survive. The
    returned values are the *conditional* logits, with every entry outside the
    guided top-k set to ``-inf``.

    Args:
        logits: Tensor whose leading dimension holds the conditional row at
            index 0 and the unconditional row at index 1.
        cfg_active: If false, ``logits`` is returned untouched.
        scale: Guidance weight passed to ``torch.lerp``.
        top_k: Number of highest guided entries to keep along the last
            dimension. Values ``<= 0`` disable filtering; values larger than
            the last dimension are clamped to it.

    Returns:
        ``logits`` itself when ``cfg_active`` is false; otherwise the
        (possibly masked) conditional row, keeping its leading dimension of
        size 1 and the dtype of the input.
    """
    if not cfg_active:
        return logits
    conditional = logits[0:1]
    unconditional = logits[1:2]
    # Blend and compare in float32; the result is cast back to the input
    # dtype on return.
    cond32 = conditional.to(torch.float32)
    uncond32 = unconditional.to(torch.float32)
    # lerp(start, end, w) == start + w * (end - start)
    guided = torch.lerp(uncond32, cond32, scale)
    if top_k > 0 and guided.shape[-1] > 0:
        k = min(top_k, guided.shape[-1])
        top_values = torch.topk(guided, k=k, dim=-1, sorted=False).values
        # With sorted=False the order of the returned values is unspecified,
        # so the k-th largest must be taken as an explicit minimum rather
        # than by indexing the last position.
        threshold = top_values.amin(dim=-1, keepdim=True)
        mask = guided >= threshold
        neg_inf = torch.full_like(cond32, float("-inf"))
        cond32 = torch.where(mask, cond32, neg_inf)
    return cond32.to(conditional.dtype)


def sample_audio_logits(logits: torch.Tensor, temp: float, top_k: int) -> torch.Tensor:
    """Draw one audio token from ``logits`` and return it with shape ``[1]``.

    Sampling is delegated to ``sample_token``; ``temp`` and ``top_k`` are
    forwarded to it unchanged, and its result is viewed as a 1-element tensor.
    """
    token = sample_token(logits, temp=temp, top_k=top_k)
    return token.view(1)